mirror of
https://github.com/langgenius/dify.git
synced 2026-02-25 02:35:12 +00:00
feat: knowledge pipeline (#25360)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: jyong <718720800@qq.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Hanqing Zhao <sherry9277@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
__all__ = ["GraphEngine"]
|
||||
|
||||
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Command Channels
|
||||
|
||||
Channel implementations for external workflow control.
|
||||
|
||||
## Components
|
||||
|
||||
### InMemoryChannel
|
||||
|
||||
Thread-safe in-memory queue for single-process deployments.
|
||||
|
||||
- `fetch_commands()` - Get pending commands
|
||||
- `send_command()` - Add command to queue
|
||||
|
||||
### RedisChannel
|
||||
|
||||
Redis-based queue for distributed deployments.
|
||||
|
||||
- `fetch_commands()` - Get commands with JSON deserialization
|
||||
- `send_command()` - Store commands with TTL
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Local execution
|
||||
channel = InMemoryChannel()
|
||||
channel.send_command(AbortCommand(graph_id="workflow-123"))
|
||||
|
||||
# Distributed execution
|
||||
redis_channel = RedisChannel(
|
||||
redis_client=redis_client,
|
||||
channel_key="workflow:123:commands"
|
||||
)
|
||||
```
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Command channel implementations for GraphEngine."""
|
||||
|
||||
from .in_memory_channel import InMemoryChannel
|
||||
from .redis_channel import RedisChannel
|
||||
|
||||
__all__ = ["InMemoryChannel", "RedisChannel"]
|
||||
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
In-memory implementation of CommandChannel for local/testing scenarios.
|
||||
|
||||
This implementation uses a thread-safe queue for command communication
|
||||
within a single process. Each instance handles commands for one workflow execution.
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
from typing import final
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryChannel:
|
||||
"""
|
||||
In-memory command channel implementation using a thread-safe queue.
|
||||
|
||||
Each instance is dedicated to a single GraphEngine/workflow execution.
|
||||
Suitable for local development, testing, and single-instance deployments.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the in-memory channel with a single queue."""
|
||||
self._queue: Queue[GraphEngineCommand] = Queue()
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from the queue.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the queue)
|
||||
"""
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Drain all available commands from the queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
command = self._queue.get_nowait()
|
||||
commands.append(command)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to this channel's queue.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
self._queue.put(command)
|
||||
114
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
114
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Redis-based implementation of CommandChannel for distributed scenarios.
|
||||
|
||||
This implementation uses Redis lists for command queuing, supporting
|
||||
multi-instance deployments and cross-server communication.
|
||||
Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
class RedisChannel:
|
||||
"""
|
||||
Redis-based command channel implementation for distributed systems.
|
||||
|
||||
Each instance uses a unique Redis key for its command queue.
|
||||
Commands are JSON-serialized for transport.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: "RedisClientWrapper",
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Redis channel.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
channel_key: Unique key for this channel's command queue
|
||||
command_ttl: TTL for command keys in seconds (default: 3600)
|
||||
"""
|
||||
self._redis = redis_client
|
||||
self._key = channel_key
|
||||
self._command_ttl = command_ttl
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from Redis.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the Redis list)
|
||||
"""
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
with self._redis.pipeline() as pipe:
|
||||
# Get all commands and clear the list atomically
|
||||
pipe.lrange(self._key, 0, -1)
|
||||
pipe.delete(self._key)
|
||||
results = pipe.execute()
|
||||
|
||||
# Parse commands from JSON
|
||||
if results[0]:
|
||||
for command_json in results[0]:
|
||||
try:
|
||||
command_data = json.loads(command_json)
|
||||
command = self._deserialize_command(command_data)
|
||||
if command:
|
||||
commands.append(command)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Skip invalid commands
|
||||
continue
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to Redis.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
command_json = json.dumps(command.model_dump())
|
||||
|
||||
# Push to list and set expiry
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.rpush(self._key, command_json)
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
Args:
|
||||
data: Command data dictionary
|
||||
|
||||
Returns:
|
||||
Deserialized command or None if invalid
|
||||
"""
|
||||
command_type_value = data.get("command_type")
|
||||
if not isinstance(command_type_value, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
command_type = CommandType(command_type_value)
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand(**data)
|
||||
else:
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand(**data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Command processing subsystem for graph engine.
|
||||
|
||||
This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Command handler implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
"""
|
||||
Handle an abort command.
|
||||
|
||||
Args:
|
||||
command: The abort command
|
||||
execution: Graph execution to abort
|
||||
"""
|
||||
assert isinstance(command, AbortCommand)
|
||||
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.abort(command.reason or "User requested abort")
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Main command processor for handling external commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol, final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
from ..protocols.command_channel import CommandChannel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandHandler(Protocol):
|
||||
"""Protocol for command handlers."""
|
||||
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
@final
|
||||
class CommandProcessor:
|
||||
"""
|
||||
Processes external commands sent to the engine.
|
||||
|
||||
This polls the command channel and dispatches commands to
|
||||
appropriate handlers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command_channel: CommandChannel,
|
||||
graph_execution: GraphExecution,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the command processor.
|
||||
|
||||
Args:
|
||||
command_channel: Channel for receiving commands
|
||||
graph_execution: Graph execution aggregate
|
||||
"""
|
||||
self._command_channel = command_channel
|
||||
self._graph_execution = graph_execution
|
||||
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
|
||||
|
||||
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
|
||||
"""
|
||||
Register a handler for a command type.
|
||||
|
||||
Args:
|
||||
command_type: Type of command to handle
|
||||
handler: Handler for the command
|
||||
"""
|
||||
self._handlers[command_type] = handler
|
||||
|
||||
def process_commands(self) -> None:
|
||||
"""Check for and process any pending commands."""
|
||||
try:
|
||||
commands = self._command_channel.fetch_commands()
|
||||
for command in commands:
|
||||
self._handle_command(command)
|
||||
except Exception as e:
|
||||
logger.warning("Error processing commands: %s", e)
|
||||
|
||||
def _handle_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Handle a single command.
|
||||
|
||||
Args:
|
||||
command: The command to handle
|
||||
"""
|
||||
handler = self._handlers.get(type(command))
|
||||
if handler:
|
||||
try:
|
||||
handler.handle(command, self._graph_execution)
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||
@@ -1,25 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
|
||||
self.init_params = init_params
|
||||
self.graph = graph
|
||||
self.condition = condition
|
||||
|
||||
@abstractmethod
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,25 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
run_result = previous_route_node_state.node_run_result
|
||||
if not run_result:
|
||||
return False
|
||||
|
||||
if not run_result.edge_source_handle:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == run_result.edge_source_handle
|
||||
@@ -1,27 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
return True
|
||||
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
_, _, final_result = condition_processor.process_conditions(
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
conditions=self.condition.conditions,
|
||||
operator="and",
|
||||
)
|
||||
|
||||
return final_result
|
||||
@@ -1,25 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
|
||||
@staticmethod
|
||||
def get_condition_handler(
|
||||
init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
|
||||
) -> RunConditionHandler:
|
||||
"""
|
||||
Get condition handler
|
||||
|
||||
:param init_params: init params
|
||||
:param graph: graph
|
||||
:param run_condition: run condition
|
||||
:return: condition handler
|
||||
"""
|
||||
if run_condition.type == "branch_identify":
|
||||
return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
|
||||
else:
|
||||
return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)
|
||||
14
api/core/workflow/graph_engine/domain/__init__.py
Normal file
14
api/core/workflow/graph_engine/domain/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Domain models for graph engine.
|
||||
|
||||
This package contains the core domain entities, value objects, and aggregates
|
||||
that represent the business concepts of workflow graph execution.
|
||||
"""
|
||||
|
||||
from .graph_execution import GraphExecution
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
__all__ = [
|
||||
"GraphExecution",
|
||||
"NodeExecution",
|
||||
]
|
||||
207
api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
207
api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""GraphExecution aggregate root managing the overall graph execution state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
|
||||
class GraphExecutionErrorState(BaseModel):
|
||||
"""Serializable representation of an execution error."""
|
||||
|
||||
module: str = Field(description="Module containing the exception class")
|
||||
qualname: str = Field(description="Qualified name of the exception class")
|
||||
message: str | None = Field(default=None, description="Exception message string")
|
||||
|
||||
|
||||
class NodeExecutionState(BaseModel):
|
||||
"""Serializable representation of a node execution entity."""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = Field(default=NodeState.UNKNOWN)
|
||||
retry_count: int = Field(default=0)
|
||||
execution_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Pydantic model describing serialized GraphExecution state."""
|
||||
|
||||
type: Literal["GraphExecution"] = Field(default="GraphExecution")
|
||||
version: str = Field(default="1.0")
|
||||
workflow_id: str
|
||||
started: bool = Field(default=False)
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
"""Convert an exception into its serializable representation."""
|
||||
|
||||
if error is None:
|
||||
return None
|
||||
|
||||
return GraphExecutionErrorState(
|
||||
module=error.__class__.__module__,
|
||||
qualname=error.__class__.__qualname__,
|
||||
message=str(error),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
|
||||
"""Locate an exception class from its module and qualified name."""
|
||||
|
||||
module = import_module(module_name)
|
||||
attr: object = module
|
||||
for part in qualname.split("."):
|
||||
attr = getattr(attr, part)
|
||||
|
||||
if isinstance(attr, type) and issubclass(attr, Exception):
|
||||
return attr
|
||||
|
||||
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
|
||||
|
||||
|
||||
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
|
||||
"""Reconstruct an exception instance from serialized data."""
|
||||
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
exception_class = _resolve_exception_class(state.module, state.qualname)
|
||||
if state.message is None:
|
||||
return exception_class()
|
||||
return exception_class(state.message)
|
||||
except Exception:
|
||||
# Fallback to RuntimeError when reconstruction fails
|
||||
if state.message is None:
|
||||
return RuntimeError(state.qualname)
|
||||
return RuntimeError(state.message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExecution:
|
||||
"""
|
||||
Aggregate root for graph execution.
|
||||
|
||||
This manages the overall execution state of a workflow graph,
|
||||
coordinating between multiple node executions.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
if self.started:
|
||||
raise RuntimeError("Graph execution already started")
|
||||
self.started = True
|
||||
|
||||
def complete(self) -> None:
|
||||
"""Mark the graph execution as completed."""
|
||||
if not self.started:
|
||||
raise RuntimeError("Cannot complete execution that hasn't started")
|
||||
if self.completed:
|
||||
raise RuntimeError("Graph execution already completed")
|
||||
self.completed = True
|
||||
|
||||
def abort(self, reason: str) -> None:
|
||||
"""Abort the graph execution."""
|
||||
self.aborted = True
|
||||
self.error = RuntimeError(f"Aborted: {reason}")
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
self.error = error
|
||||
self.completed = True
|
||||
|
||||
def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
|
||||
"""Get or create a node execution entity."""
|
||||
if node_id not in self.node_executions:
|
||||
self.node_executions[node_id] = NodeExecution(node_id=node_id)
|
||||
return self.node_executions[node_id]
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the execution is currently running."""
|
||||
return self.started and not self.completed and not self.aborted
|
||||
|
||||
@property
|
||||
def has_error(self) -> bool:
|
||||
"""Check if the execution has encountered an error."""
|
||||
return self.error is not None
|
||||
|
||||
@property
|
||||
def error_message(self) -> str | None:
|
||||
"""Get the error message if an error exists."""
|
||||
if not self.error:
|
||||
return None
|
||||
return str(self.error)
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the aggregate state into a JSON string."""
|
||||
|
||||
node_states = [
|
||||
NodeExecutionState(
|
||||
node_id=node_id,
|
||||
state=node_execution.state,
|
||||
retry_count=node_execution.retry_count,
|
||||
execution_id=node_execution.execution_id,
|
||||
error=node_execution.error,
|
||||
)
|
||||
for node_id, node_execution in sorted(self.node_executions.items())
|
||||
]
|
||||
|
||||
state = GraphExecutionState(
|
||||
workflow_id=self.workflow_id,
|
||||
started=self.started,
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
error=_serialize_error(self.error),
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore aggregate state from a serialized JSON string."""
|
||||
|
||||
state = GraphExecutionState.model_validate_json(data)
|
||||
|
||||
if state.type != "GraphExecution":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
if self.workflow_id != state.workflow_id:
|
||||
raise ValueError("Serialized workflow_id does not match aggregate identity")
|
||||
|
||||
self.started = state.started
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
state=item.state,
|
||||
retry_count=item.retry_count,
|
||||
execution_id=item.execution_id,
|
||||
error=item.error,
|
||||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
45
api/core/workflow/graph_engine/domain/node_execution.py
Normal file
45
api/core/workflow/graph_engine/domain/node_execution.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
NodeExecution entity representing a node's execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeExecution:
|
||||
"""
|
||||
Entity representing the execution state of a single node.
|
||||
|
||||
This is a mutable entity that tracks the runtime state of a node
|
||||
during graph execution.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = NodeState.UNKNOWN
|
||||
retry_count: int = 0
|
||||
execution_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def mark_started(self, execution_id: str) -> None:
|
||||
"""Mark the node as started with an execution ID."""
|
||||
self.state = NodeState.TAKEN
|
||||
self.execution_id = execution_id
|
||||
|
||||
def mark_taken(self) -> None:
|
||||
"""Mark the node as successfully completed."""
|
||||
self.state = NodeState.TAKEN
|
||||
self.error = None
|
||||
|
||||
def mark_failed(self, error: str) -> None:
|
||||
"""Mark the node as failed with an error."""
|
||||
self.error = error
|
||||
|
||||
def mark_skipped(self) -> None:
|
||||
"""Mark the node as skipped."""
|
||||
self.state = NodeState.SKIPPED
|
||||
|
||||
def increment_retry(self) -> None:
|
||||
"""Increment the retry count for this node."""
|
||||
self.retry_count += 1
|
||||
@@ -1,6 +0,0 @@
|
||||
from .graph import Graph
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .runtime_route_state import RuntimeRouteState
|
||||
|
||||
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
|
||||
33
api/core/workflow/graph_engine/entities/commands.py
Normal file
33
api/core/workflow/graph_engine/entities/commands.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
GraphEngine command entities for external control.
|
||||
|
||||
This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
RESUME = "resume"
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
"""Base class for all GraphEngine commands."""
|
||||
|
||||
command_type: CommandType = Field(..., description="Type of command")
|
||||
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
|
||||
|
||||
|
||||
class AbortCommand(GraphEngineCommand):
|
||||
"""Command to abort a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
@@ -1,277 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
outputs: dict[str, Any] | None = None
|
||||
"""outputs"""
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
|
||||
id: str = Field(..., description="node execution id")
|
||||
node_id: str = Field(..., description="node id")
|
||||
node_type: NodeType = Field(..., description="node type")
|
||||
node_data: BaseNodeData = Field(..., description="node data")
|
||||
route_node_state: RouteNodeState = Field(..., description="route node state")
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
# The version of the node, or "1" if not specified.
|
||||
node_version: str = "1"
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
predecessor_node_id: str | None = None
|
||||
"""predecessor node id"""
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration node parallel mode run id"""
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: list[str] | None = None
|
||||
"""from variable selector"""
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
pass
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInLoopFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
start_at: datetime = Field(..., description="retry start time")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelBranchEvent(GraphEngineEvent):
|
||||
parallel_id: str = Field(..., description="parallel id")
|
||||
"""parallel id"""
|
||||
parallel_start_node_id: str = Field(..., description="parallel start node id")
|
||||
"""parallel start node id"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
|
||||
iteration_id: str = Field(..., description="iteration node execution id")
|
||||
iteration_node_id: str = Field(..., description="iteration node id")
|
||||
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
|
||||
iteration_node_data: BaseNodeData = Field(..., description="node data")
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class IterationRunNextEvent(BaseIterationEvent):
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Any | None = None
|
||||
duration: float | None = None
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
iteration_duration_map: dict[str, float] | None = None
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Loop Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseLoopEvent(GraphEngineEvent):
|
||||
loop_id: str = Field(..., description="loop node execution id")
|
||||
loop_node_id: str = Field(..., description="loop node id")
|
||||
loop_node_type: NodeType = Field(..., description="node type, loop or loop")
|
||||
loop_node_data: BaseNodeData = Field(..., description="node data")
|
||||
parallel_id: str | None = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: str | None = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: str | None = None
|
||||
"""loop run in parallel mode run id"""
|
||||
|
||||
|
||||
class LoopRunStartedEvent(BaseLoopEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class LoopRunNextEvent(BaseLoopEvent):
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Any | None = None
|
||||
duration: float | None = None
|
||||
|
||||
|
||||
class LoopRunSucceededEvent(BaseLoopEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
loop_duration_map: dict[str, float] | None = None
|
||||
|
||||
|
||||
class LoopRunFailedEvent(BaseLoopEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Agent Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseAgentEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class AgentLogEvent(BaseAgentEvent):
|
||||
id: str = Field(..., description="id")
|
||||
label: str = Field(..., description="label")
|
||||
node_execution_id: str = Field(..., description="node execution id")
|
||||
parent_id: str | None = Field(..., description="parent id")
|
||||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Mapping[str, Any] | None = Field(default=None, description="metadata")
|
||||
node_id: str = Field(..., description="agent node id")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
|
||||
@@ -1,674 +0,0 @@
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
source_node_id: str = Field(..., description="source node id")
|
||||
target_node_id: str = Field(..., description="target node id")
|
||||
run_condition: RunCondition | None = None
|
||||
"""run condition"""
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
|
||||
start_from_node_id: str = Field(..., description="start from node id")
|
||||
parent_parallel_id: str | None = None
|
||||
"""parent parallel id"""
|
||||
parent_parallel_start_node_id: str | None = None
|
||||
"""parent parallel start node id"""
|
||||
end_to_node_id: str | None = None
|
||||
"""end to node id"""
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
root_node_id: str = Field(..., description="root node id of the graph")
|
||||
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
|
||||
node_id_config_mapping: dict[str, dict] = Field(
|
||||
default_factory=dict, description="node configs mapping (node id: node config)"
|
||||
)
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict, description="graph edge mapping (source node id: edges)"
|
||||
)
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
|
||||
)
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(
|
||||
default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
|
||||
)
|
||||
node_parallel_mapping: dict[str, str] = Field(
|
||||
default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
|
||||
)
|
||||
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
|
||||
end_stream_param: EndStreamParam = Field(..., description="end stream param")
|
||||
|
||||
@classmethod
|
||||
def init(cls, graph_config: Mapping[str, Any], root_node_id: str | None = None) -> "Graph":
|
||||
"""
|
||||
Init graph
|
||||
|
||||
:param graph_config: graph config
|
||||
:param root_node_id: root node id
|
||||
:return: graph
|
||||
"""
|
||||
# edge configs
|
||||
edge_configs = graph_config.get("edges")
|
||||
if edge_configs is None:
|
||||
edge_configs = []
|
||||
# node configs
|
||||
node_configs = graph_config.get("nodes")
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
node_configs = cast(list, node_configs)
|
||||
|
||||
# reorganize edges mapping
|
||||
edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
target_edge_ids = set()
|
||||
fail_branch_source_node_id = [
|
||||
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
|
||||
]
|
||||
for edge_config in edge_configs:
|
||||
source_node_id = edge_config.get("source")
|
||||
if not source_node_id:
|
||||
continue
|
||||
|
||||
if source_node_id not in edge_mapping:
|
||||
edge_mapping[source_node_id] = []
|
||||
|
||||
target_node_id = edge_config.get("target")
|
||||
if not target_node_id:
|
||||
continue
|
||||
|
||||
if target_node_id not in reverse_edge_mapping:
|
||||
reverse_edge_mapping[target_node_id] = []
|
||||
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get("sourceHandle"):
|
||||
if (
|
||||
edge_config.get("source") in fail_branch_source_node_id
|
||||
and edge_config.get("sourceHandle") != "fail-branch"
|
||||
):
|
||||
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
|
||||
elif edge_config.get("sourceHandle") != "source":
|
||||
run_condition = RunCondition(
|
||||
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
|
||||
)
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
|
||||
)
|
||||
|
||||
edge_mapping[source_node_id].append(graph_edge)
|
||||
reverse_edge_mapping[target_node_id].append(graph_edge)
|
||||
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
continue
|
||||
|
||||
if node_id not in target_edge_ids:
|
||||
root_node_configs.append(node_config)
|
||||
|
||||
all_node_id_config_mapping[node_id] = node_config
|
||||
|
||||
root_node_ids = [node_config.get("id") for node_config in root_node_configs]
|
||||
|
||||
# fetch root node
|
||||
if not root_node_id:
|
||||
# if no root node id, use the START type node as root node
|
||||
root_node_id = next(
|
||||
(
|
||||
node_config.get("id")
|
||||
for node_config in root_node_configs
|
||||
if node_config.get("data", {}).get("type", "") == NodeType.START.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not root_node_id or root_node_id not in root_node_ids:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
|
||||
# Check whether it is connected to the previous node
|
||||
cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)
|
||||
|
||||
# fetch all node ids from root node
|
||||
node_ids = [root_node_id]
|
||||
cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)
|
||||
|
||||
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
|
||||
|
||||
# init parallel mapping
|
||||
parallel_mapping: dict[str, GraphParallel] = {}
|
||||
node_parallel_mapping: dict[str, str] = {}
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=root_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
)
|
||||
|
||||
# Check if it exceeds N layers of parallel
|
||||
for parallel in parallel_mapping.values():
|
||||
if parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
parent_parallel_id=parallel.parent_parallel_id,
|
||||
)
|
||||
|
||||
# init answer stream generate routes
|
||||
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
|
||||
)
|
||||
|
||||
# init end stream param
|
||||
end_stream_param = EndStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = cls(
|
||||
root_node_id=root_node_id,
|
||||
node_ids=node_ids,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
answer_stream_generate_routes=answer_stream_generate_routes,
|
||||
end_stream_param=end_stream_param,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str):
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param node_ids: node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param node_id: node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(node_id, []):
|
||||
if graph_edge.target_node_id in node_ids:
|
||||
continue
|
||||
|
||||
node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]):
|
||||
"""
|
||||
Check whether it is connected to the previous node
|
||||
"""
|
||||
last_node_id = route[-1]
|
||||
|
||||
for graph_edge in edge_mapping.get(last_node_id, []):
|
||||
if not graph_edge.target_node_id:
|
||||
continue
|
||||
|
||||
if graph_edge.target_node_id in route:
|
||||
raise ValueError(
|
||||
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
|
||||
)
|
||||
|
||||
new_route = route.copy()
|
||||
new_route.append(graph_edge.target_node_id)
|
||||
cls._check_connected_to_previous_node(
|
||||
route=new_route,
|
||||
edge_mapping=edge_mapping,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallels(
|
||||
cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
node_parallel_mapping: dict[str, str],
|
||||
parent_parallel: GraphParallel | None = None,
|
||||
):
|
||||
"""
|
||||
Recursively add parallel ids
|
||||
|
||||
:param edge_mapping: edge mapping
|
||||
:param start_node_id: start from node id
|
||||
:param parallel_mapping: parallel mapping
|
||||
:param node_parallel_mapping: node parallel mapping
|
||||
:param parent_parallel: parent parallel
|
||||
"""
|
||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||
parallel = None
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_branch_node_ids = defaultdict(list)
|
||||
condition_edge_mappings = defaultdict(list)
|
||||
for graph_edge in target_node_edges:
|
||||
if graph_edge.run_condition is None:
|
||||
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
|
||||
else:
|
||||
condition_hash = graph_edge.run_condition.hash
|
||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||
|
||||
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||
if len(graph_edges) > 1:
|
||||
for graph_edge in graph_edges:
|
||||
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
|
||||
|
||||
condition_parallels = {}
|
||||
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
|
||||
# any target node id in node_parallel_mapping
|
||||
parallel = None
|
||||
if condition_parallel_branch_node_ids:
|
||||
parent_parallel_id = parent_parallel.id if parent_parallel else None
|
||||
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
condition_parallels[condition_hash] = parallel
|
||||
|
||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_branch_node_ids=condition_parallel_branch_node_ids,
|
||||
)
|
||||
|
||||
# collect all branches node ids
|
||||
parallel_node_ids = []
|
||||
for _, node_ids in in_branch_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
in_parent_parallel = True
|
||||
if parent_parallel_id:
|
||||
in_parent_parallel = False
|
||||
for parallel_node_id, parallel_id in node_parallel_mapping.items():
|
||||
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
|
||||
in_parent_parallel = True
|
||||
break
|
||||
|
||||
if in_parent_parallel:
|
||||
parallel_node_ids.append(node_id)
|
||||
node_parallel_mapping[node_id] = parallel.id
|
||||
|
||||
outside_parallel_target_node_ids = set()
|
||||
for node_id in parallel_node_ids:
|
||||
if node_id == parallel.start_from_node_id:
|
||||
continue
|
||||
|
||||
node_edges = edge_mapping.get(node_id)
|
||||
if not node_edges:
|
||||
continue
|
||||
|
||||
if len(node_edges) > 1:
|
||||
continue
|
||||
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if target_node_id in parallel_node_ids:
|
||||
continue
|
||||
|
||||
if parent_parallel_id:
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
continue
|
||||
|
||||
if (
|
||||
(
|
||||
node_parallel_mapping.get(target_node_id)
|
||||
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
|
||||
)
|
||||
or (
|
||||
parent_parallel
|
||||
and parent_parallel.end_to_node_id
|
||||
and target_node_id == parent_parallel.end_to_node_id
|
||||
)
|
||||
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
||||
):
|
||||
outside_parallel_target_node_ids.add(target_node_id)
|
||||
|
||||
if len(outside_parallel_target_node_ids) == 1:
|
||||
if (
|
||||
parent_parallel
|
||||
and parent_parallel.end_to_node_id
|
||||
and parallel.end_to_node_id == parent_parallel.end_to_node_id
|
||||
):
|
||||
parallel.end_to_node_id = None
|
||||
else:
|
||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||
|
||||
if condition_edge_mappings:
|
||||
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||
for graph_edge in graph_edges:
|
||||
current_parallel = cls._get_current_parallel(
|
||||
parallel_mapping=parallel_mapping,
|
||||
graph_edge=graph_edge,
|
||||
parallel=condition_parallels.get(condition_hash),
|
||||
parent_parallel=parent_parallel,
|
||||
)
|
||||
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
parent_parallel=current_parallel,
|
||||
)
|
||||
else:
|
||||
for graph_edge in target_node_edges:
|
||||
current_parallel = cls._get_current_parallel(
|
||||
parallel_mapping=parallel_mapping,
|
||||
graph_edge=graph_edge,
|
||||
parallel=parallel,
|
||||
parent_parallel=parent_parallel,
|
||||
)
|
||||
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
parent_parallel=current_parallel,
|
||||
)
|
||||
else:
|
||||
for graph_edge in target_node_edges:
|
||||
current_parallel = cls._get_current_parallel(
|
||||
parallel_mapping=parallel_mapping,
|
||||
graph_edge=graph_edge,
|
||||
parallel=parallel,
|
||||
parent_parallel=parent_parallel,
|
||||
)
|
||||
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
parent_parallel=current_parallel,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_current_parallel(
|
||||
cls,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
graph_edge: GraphEdge,
|
||||
parallel: GraphParallel | None = None,
|
||||
parent_parallel: GraphParallel | None = None,
|
||||
) -> GraphParallel | None:
|
||||
"""
|
||||
Get current parallel
|
||||
"""
|
||||
current_parallel = None
|
||||
if parallel:
|
||||
current_parallel = parallel
|
||||
elif parent_parallel:
|
||||
if not parent_parallel.end_to_node_id or (
|
||||
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
|
||||
):
|
||||
current_parallel = parent_parallel
|
||||
else:
|
||||
# fetch parent parallel's parent parallel
|
||||
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
|
||||
if parent_parallel_parent_parallel_id:
|
||||
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
|
||||
if parent_parallel_parent_parallel and (
|
||||
not parent_parallel_parent_parallel.end_to_node_id
|
||||
or (
|
||||
parent_parallel_parent_parallel.end_to_node_id
|
||||
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
|
||||
)
|
||||
):
|
||||
current_parallel = parent_parallel_parent_parallel
|
||||
|
||||
return current_parallel
|
||||
|
||||
@classmethod
|
||||
def _check_exceed_parallel_limit(
|
||||
cls,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
level_limit: int,
|
||||
parent_parallel_id: str,
|
||||
current_level: int = 1,
|
||||
):
|
||||
"""
|
||||
Check if it exceeds N layers of parallel
|
||||
"""
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
return
|
||||
|
||||
current_level += 1
|
||||
if current_level > level_limit:
|
||||
raise ValueError(f"Exceeds {level_limit} layers of parallel")
|
||||
|
||||
if parent_parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=level_limit,
|
||||
parent_parallel_id=parent_parallel.parent_parallel_id,
|
||||
current_level=current_level,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallel_node_ids(
|
||||
cls,
|
||||
branch_node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
merge_node_id: str,
|
||||
start_node_id: str,
|
||||
):
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param branch_node_ids: in branch node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param merge_node_id: merge node id
|
||||
:param start_node_id: start node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(start_node_id, []):
|
||||
if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
|
||||
branch_node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=branch_node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _fetch_all_node_ids_in_parallels(
|
||||
cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
parallel_branch_node_ids: list[str],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch all node ids in parallels
|
||||
"""
|
||||
routes_node_ids: dict[str, list[str]] = {}
|
||||
for parallel_branch_node_id in parallel_branch_node_ids:
|
||||
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
|
||||
|
||||
# fetch routes node ids
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=parallel_branch_node_id,
|
||||
routes_node_ids=routes_node_ids[parallel_branch_node_id],
|
||||
)
|
||||
|
||||
# fetch leaf node ids from routes node ids
|
||||
leaf_node_ids: dict[str, list[str]] = {}
|
||||
merge_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
|
||||
if branch_node_id not in leaf_node_ids:
|
||||
leaf_node_ids[branch_node_id] = []
|
||||
|
||||
leaf_node_ids[branch_node_id].append(node_id)
|
||||
|
||||
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
||||
if (
|
||||
branch_node_id != branch_node_id2
|
||||
and node_id in inner_route2
|
||||
and len(reverse_edge_mapping.get(node_id, [])) > 1
|
||||
and cls._is_node_in_routes(
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=node_id,
|
||||
routes_node_ids=routes_node_ids,
|
||||
)
|
||||
):
|
||||
if node_id not in merge_branch_node_ids:
|
||||
merge_branch_node_ids[node_id] = []
|
||||
|
||||
if branch_node_id2 not in merge_branch_node_ids[node_id]:
|
||||
merge_branch_node_ids[node_id].append(branch_node_id2)
|
||||
|
||||
# sorted merge_branch_node_ids by branch_node_ids length desc
|
||||
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
|
||||
|
||||
duplicate_end_node_ids = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
|
||||
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
|
||||
if (node_id, node_id2) not in duplicate_end_node_ids and (
|
||||
node_id2,
|
||||
node_id,
|
||||
) not in duplicate_end_node_ids:
|
||||
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
|
||||
|
||||
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
|
||||
# check which node is after
|
||||
if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
|
||||
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
|
||||
del merge_branch_node_ids[node_id2]
|
||||
elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
|
||||
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
|
||||
del merge_branch_node_ids[node_id]
|
||||
|
||||
branches_merge_node_ids: dict[str, str] = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
if len(branch_node_ids) <= 1:
|
||||
continue
|
||||
|
||||
for branch_node_id in branch_node_ids:
|
||||
if branch_node_id in branches_merge_node_ids:
|
||||
continue
|
||||
|
||||
branches_merge_node_ids[branch_node_id] = node_id
|
||||
|
||||
in_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
in_branch_node_ids[branch_node_id] = []
|
||||
if branch_node_id not in branches_merge_node_ids:
|
||||
# all node ids in current branch is in this thread
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
in_branch_node_ids[branch_node_id].extend(node_ids)
|
||||
else:
|
||||
merge_node_id = branches_merge_node_ids[branch_node_id]
|
||||
if merge_node_id != branch_node_id:
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
|
||||
# fetch all node ids from branch_node_id and merge_node_id
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=in_branch_node_ids[branch_node_id],
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=branch_node_id,
|
||||
)
|
||||
|
||||
return in_branch_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_fetch_routes(
|
||||
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
|
||||
):
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
if start_node_id not in edge_mapping:
|
||||
return
|
||||
|
||||
for graph_edge in edge_mapping[start_node_id]:
|
||||
# find next node ids
|
||||
if graph_edge.target_node_id not in routes_node_ids:
|
||||
routes_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_node_in_routes(
|
||||
cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
|
||||
) -> bool:
|
||||
"""
|
||||
Recursively check if the node is in the routes
|
||||
"""
|
||||
if start_node_id not in reverse_edge_mapping:
|
||||
return False
|
||||
|
||||
parallel_start_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id in routes_node_ids:
|
||||
if branch_node_id in reverse_edge_mapping:
|
||||
for graph_edge in reverse_edge_mapping[branch_node_id]:
|
||||
if graph_edge.source_node_id not in parallel_start_node_ids:
|
||||
parallel_start_node_ids[graph_edge.source_node_id] = []
|
||||
|
||||
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
|
||||
|
||||
expected_branch_set = set(routes_node_ids.keys())
|
||||
for _, branch_node_ids in parallel_start_node_ids.items():
|
||||
if set(branch_node_ids) == expected_branch_set:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
|
||||
"""
|
||||
is node2 after node1
|
||||
"""
|
||||
if node1_id not in edge_mapping:
|
||||
return False
|
||||
|
||||
for graph_edge in edge_mapping[node1_id]:
|
||||
if graph_edge.target_node_id == node2_id:
|
||||
return True
|
||||
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1,21 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
"""variable pool"""
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
"""start time"""
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
|
||||
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
|
||||
#
|
||||
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
|
||||
# after a serialization and deserialization round trip.
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
node_run_state: RuntimeRouteState = RuntimeRouteState()
|
||||
"""node run state"""
|
||||
@@ -1,21 +0,0 @@
|
||||
import hashlib
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: str | None = None
|
||||
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||
|
||||
conditions: list[Condition] | None = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
|
||||
@@ -1,117 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
class Status(StrEnum):
|
||||
RUNNING = auto()
|
||||
SUCCESS = auto()
|
||||
FAILED = auto()
|
||||
PAUSED = auto()
|
||||
EXCEPTION = auto()
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
|
||||
node_id: str
|
||||
"""node id"""
|
||||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.RUNNING
|
||||
"""node status"""
|
||||
|
||||
start_at: datetime
|
||||
"""start time"""
|
||||
|
||||
paused_at: datetime | None = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: datetime | None = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: str | None = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: str | None = None
|
||||
"""paused by"""
|
||||
|
||||
index: int = 1
|
||||
|
||||
def set_finished(self, run_result: NodeRunResult):
|
||||
"""
|
||||
Node finished
|
||||
|
||||
:param run_result: run result
|
||||
"""
|
||||
if self.status in {
|
||||
RouteNodeState.Status.SUCCESS,
|
||||
RouteNodeState.Status.FAILED,
|
||||
RouteNodeState.Status.EXCEPTION,
|
||||
}:
|
||||
raise Exception(f"Route state {self.id} already finished")
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
self.status = RouteNodeState.Status.SUCCESS
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
self.status = RouteNodeState.Status.FAILED
|
||||
self.failed_reason = run_result.error
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
self.status = RouteNodeState.Status.EXCEPTION
|
||||
self.failed_reason = run_result.error
|
||||
else:
|
||||
raise Exception(f"Invalid route status {run_result.status}")
|
||||
|
||||
self.node_run_result = run_result
|
||||
self.finished_at = naive_utc_now()
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
|
||||
routes: dict[str, list[str]] = Field(
|
||||
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
|
||||
)
|
||||
|
||||
node_state_mapping: dict[str, RouteNodeState] = Field(
|
||||
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
|
||||
)
|
||||
|
||||
def create_node_state(self, node_id: str) -> RouteNodeState:
|
||||
"""
|
||||
Create node state
|
||||
|
||||
:param node_id: node id
|
||||
"""
|
||||
state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
|
||||
self.node_state_mapping[state.id] = state
|
||||
return state
|
||||
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str):
|
||||
"""
|
||||
Add route to the graph state
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:param target_node_state_id: target node state id
|
||||
"""
|
||||
if source_node_state_id not in self.routes:
|
||||
self.routes[source_node_state_id] = []
|
||||
|
||||
self.routes[source_node_state_id].append(target_node_state_id)
|
||||
|
||||
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
|
||||
"""
|
||||
Get routes with node state by source node id
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:return: routes with node state
|
||||
"""
|
||||
return [
|
||||
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
|
||||
]
|
||||
211
api/core/workflow/graph_engine/error_handler.py
Normal file
211
api/core/workflow/graph_engine/error_handler.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy as ErrorStrategyEnum,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetryEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .domain import GraphExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
|
||||
This acts as a facade for the various error strategies,
|
||||
selecting and applying the appropriate strategy based on
|
||||
node configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
|
||||
"""
|
||||
Initialize the error handler.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_execution: The graph execution state
|
||||
"""
|
||||
self._graph = graph
|
||||
self._graph_execution = graph_execution
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
Selects and applies the appropriate error strategy based on
|
||||
the node's configuration.
|
||||
|
||||
Args:
|
||||
event: The node failure event
|
||||
|
||||
Returns:
|
||||
Optional new event to process, or None to abort
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
# Get retry count from NodeExecution
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
retry_count = node_execution.retry_count
|
||||
|
||||
# First check if retry is configured and not exhausted
|
||||
if node.retry and retry_count < node.retry_config.max_retries:
|
||||
result = self._handle_retry(event, retry_count)
|
||||
if result:
|
||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||
return result
|
||||
|
||||
# Apply configured error strategy
|
||||
strategy = node.error_strategy
|
||||
|
||||
match strategy:
|
||||
case None:
|
||||
return self._handle_abort(event)
|
||||
case ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self._handle_fail_branch(event)
|
||||
case ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self._handle_default_value(event)
|
||||
|
||||
def _handle_abort(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by aborting execution.
|
||||
|
||||
This is the default strategy when no other strategy is specified.
|
||||
It stops the entire graph execution when a node fails.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
None - signals abortion
|
||||
"""
|
||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||
# Return None to signal that execution should stop
|
||||
|
||||
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
|
||||
"""
|
||||
Handle error by retrying the node.
|
||||
|
||||
This strategy re-attempts node execution up to a configured
|
||||
maximum number of retries with configurable intervals.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
retry_count: Current retry attempt count
|
||||
|
||||
Returns:
|
||||
NodeRunRetryEvent if retry should occur, None otherwise
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
# Check if we've exceeded max retries
|
||||
if not node.retry or retry_count >= node.retry_config.max_retries:
|
||||
return None
|
||||
|
||||
# Wait for retry interval
|
||||
time.sleep(node.retry_config.retry_interval_seconds)
|
||||
|
||||
# Create retry event
|
||||
return NodeRunRetryEvent(
|
||||
id=event.id,
|
||||
node_title=node.title,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_run_result=event.node_run_result,
|
||||
start_at=event.start_at,
|
||||
error=event.error,
|
||||
retry_index=retry_count + 1,
|
||||
)
|
||||
|
||||
def _handle_fail_branch(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by taking the fail branch.
|
||||
|
||||
This strategy converts failures to exceptions and routes execution
|
||||
through a designated fail-branch edge.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent to continue via fail branch
|
||||
"""
|
||||
outputs = {
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
edge_source_handle="fail-branch",
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_default_value(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by using default values.
|
||||
|
||||
This strategy allows nodes to fail gracefully by providing
|
||||
predefined default output values.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent with default values
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
outputs = {
|
||||
**node.default_value_dict,
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
14
api/core/workflow/graph_engine/event_management/__init__.py
Normal file
14
api/core/workflow/graph_engine/event_management/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Event management subsystem for graph engine.
|
||||
|
||||
This package handles event routing, collection, and emission for
|
||||
workflow graph execution events.
|
||||
"""
|
||||
|
||||
from .event_handlers import EventHandler
|
||||
from .event_manager import EventManager
|
||||
|
||||
__all__ = [
|
||||
"EventHandler",
|
||||
"EventManager",
|
||||
]
|
||||
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..error_handler import ErrorHandler
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..graph_traversal import EdgeProcessor
|
||||
from .event_manager import EventManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandler:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
||||
This centralizes the business logic for handling specific events,
|
||||
keeping it separate from the routing and collection infrastructure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: "EventManager",
|
||||
edge_processor: "EdgeProcessor",
|
||||
state_manager: "GraphStateManager",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Event manager for collecting events
|
||||
edge_processor: Edge processor for edge traversal
|
||||
state_manager: Unified state manager
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._edge_processor = edge_processor
|
||||
self._state_manager = state_manager
|
||||
self._error_handler = error_handler
|
||||
|
||||
def dispatch(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
Handle any node event by dispatching to the appropriate handler.
|
||||
|
||||
Args:
|
||||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
return self._dispatch(event)
|
||||
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: GraphNodeEventBase) -> None:
|
||||
self._event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
@_dispatch.register(NodeRunIterationStartedEvent)
|
||||
@_dispatch.register(NodeRunIterationNextEvent)
|
||||
@_dispatch.register(NodeRunIterationSucceededEvent)
|
||||
@_dispatch.register(NodeRunIterationFailedEvent)
|
||||
@_dispatch.register(NodeRunLoopStartedEvent)
|
||||
@_dispatch.register(NodeRunLoopNextEvent)
|
||||
@_dispatch.register(NodeRunLoopSucceededEvent)
|
||||
@_dispatch.register(NodeRunLoopFailedEvent)
|
||||
@_dispatch.register(NodeRunAgentLogEvent)
|
||||
def _(self, event: GraphNodeEventBase) -> None:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle node started event.
|
||||
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Handle stream chunk event with full processing.
|
||||
|
||||
Args:
|
||||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self._response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Handle node success by coordinating subsystems.
|
||||
|
||||
This method coordinates between different subsystems to process
|
||||
node completion, handle edges, and trigger downstream execution.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle node failure using error handler.
|
||||
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.dispatch(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
Handle node exception event (fail-branch strategy).
|
||||
|
||||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
Handle node retry event.
|
||||
|
||||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
|
||||
else:
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
193
api/core/workflow/graph_engine/event_management/event_manager.py
Normal file
193
api/core/workflow/graph_engine/event_management/event_manager.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Unified event manager for collecting and emitting events.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock implementation that allows multiple concurrent readers
|
||||
but only one writer at a time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._read_ready = threading.Condition(threading.RLock())
|
||||
self._readers = 0
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._read_ready.notify_all()
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
_ = self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
def read_lock(self) -> "ReadLockContext":
|
||||
"""Return a context manager for read locking."""
|
||||
return ReadLockContext(self)
|
||||
|
||||
def write_lock(self) -> "WriteLockContext":
|
||||
"""Return a context manager for write locking."""
|
||||
return WriteLockContext(self)
|
||||
|
||||
|
||||
@final
|
||||
class ReadLockContext:
|
||||
"""Context manager for read locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "ReadLockContext":
|
||||
self._lock.acquire_read()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_read()
|
||||
|
||||
|
||||
@final
|
||||
class WriteLockContext:
|
||||
"""Context manager for write locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "WriteLockContext":
|
||||
self._lock.acquire_write()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_write()
|
||||
|
||||
|
||||
@final
|
||||
class EventManager:
|
||||
"""
|
||||
Unified event manager that collects, buffers, and emits events.
|
||||
|
||||
This class combines event collection with event emission, providing
|
||||
thread-safe event management with support for notifying layers and
|
||||
streaming events to external consumers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event manager."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
|
||||
"""
|
||||
Set the layers to notify on event collection.
|
||||
|
||||
Args:
|
||||
layers: List of layers to notify
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get new events starting from a specific index.
|
||||
|
||||
Args:
|
||||
start_index: The index to start from
|
||||
|
||||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def _event_count(self) -> int:
|
||||
"""
|
||||
Get the current count of collected events.
|
||||
|
||||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the event emission generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self._event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self._get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Notify all layers of an event.
|
||||
|
||||
Layer exceptions are caught and logged to prevent disrupting collection.
|
||||
|
||||
Args:
|
||||
event: The event to send to layers
|
||||
"""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_event(event)
|
||||
except Exception:
|
||||
# Silently ignore layer errors during collection
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
288
api/core/workflow/graph_engine/graph_state_manager.py
Normal file
288
api/core/workflow/graph_engine/graph_state_manager.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Graph state manager that combines node, edge, and execution tracking.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
|
||||
"""Analysis result for edge states."""
|
||||
|
||||
has_unknown: bool
|
||||
has_taken: bool
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class GraphStateManager:
|
||||
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
|
||||
"""
|
||||
Initialize the state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self._graph = graph
|
||||
self._ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Execution tracking state
|
||||
self._executing_nodes: set[str] = set()
|
||||
|
||||
# ============= Node State Operations =============
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as TAKEN and add it to the ready queue.
|
||||
|
||||
This combines the state transition and enqueueing operations
|
||||
that always occur together when preparing a node for execution.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self._ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as SKIPPED.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when all its incoming edges from taken branches
|
||||
have been satisfied.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# If any edge is UNKNOWN, node is not ready
|
||||
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
|
||||
return False
|
||||
|
||||
# Node is ready if at least one edge is TAKEN
|
||||
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
|
||||
|
||||
def get_node_state(self, node_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.nodes[node_id].state
|
||||
|
||||
# ============= Edge State Operations =============
|
||||
|
||||
def mark_edge_taken(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as TAKEN.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as SKIPPED.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
Analyze the states of edges and return summary flags.
|
||||
|
||||
Args:
|
||||
edges: List of edges to analyze
|
||||
|
||||
Returns:
|
||||
Analysis result with state flags
|
||||
"""
|
||||
with self._lock:
|
||||
states = {edge.state for edge in edges}
|
||||
|
||||
return EdgeStateAnalysis(
|
||||
has_unknown=NodeState.UNKNOWN in states,
|
||||
has_taken=NodeState.TAKEN in states,
|
||||
all_skipped=states == {NodeState.SKIPPED} if states else True,
|
||||
)
|
||||
|
||||
def get_edge_state(self, edge_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of an edge.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge
|
||||
|
||||
Returns:
|
||||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
selected_edges.append(edge)
|
||||
else:
|
||||
unselected_edges.append(edge)
|
||||
|
||||
return selected_edges, unselected_edges
|
||||
|
||||
# ============= Execution Tracking Operations =============
|
||||
|
||||
def start_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node starting execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.add(node_id)
|
||||
|
||||
def finish_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as no longer executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node finishing execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.discard(node_id)
|
||||
|
||||
def is_executing(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is currently executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is executing
|
||||
"""
|
||||
with self._lock:
|
||||
return node_id in self._executing_nodes
|
||||
|
||||
def get_executing_count(self) -> int:
|
||||
"""
|
||||
Get the count of currently executing nodes.
|
||||
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
def get_executing_nodes(self) -> set[str]:
|
||||
"""
|
||||
Get a copy of the set of executing node IDs.
|
||||
|
||||
Returns:
|
||||
Set of node IDs currently executing
|
||||
"""
|
||||
with self._lock:
|
||||
return self._executing_nodes.copy()
|
||||
|
||||
def clear_executing(self) -> None:
|
||||
"""Clear all executing nodes."""
|
||||
with self._lock:
|
||||
self._executing_nodes.clear()
|
||||
|
||||
# ============= Composite Operations =============
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if graph execution is complete.
|
||||
|
||||
Execution is complete when:
|
||||
- Ready queue is empty
|
||||
- No nodes are executing
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
with self._lock:
|
||||
return self._ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
|
||||
def get_queue_depth(self) -> int:
|
||||
"""
|
||||
Get the current depth of the ready queue.
|
||||
|
||||
Returns:
|
||||
Number of nodes in the ready queue
|
||||
"""
|
||||
return self._ready_queue.qsize()
|
||||
|
||||
def get_execution_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get execution statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with execution statistics
|
||||
"""
|
||||
with self._lock:
|
||||
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
|
||||
return {
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"executing": len(self._executing_nodes),
|
||||
"taken_nodes": taken_nodes,
|
||||
"skipped_nodes": skipped_nodes,
|
||||
"unknown_nodes": unknown_nodes,
|
||||
}
|
||||
14
api/core/workflow/graph_engine/graph_traversal/__init__.py
Normal file
14
api/core/workflow/graph_engine/graph_traversal/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Graph traversal subsystem for graph engine.
|
||||
|
||||
This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"EdgeProcessor",
|
||||
"SkipPropagator",
|
||||
]
|
||||
201
api/core/workflow/graph_engine/graph_traversal/edge_processor.py
Normal file
201
api/core/workflow/graph_engine/graph_traversal/edge_processor.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
selected_handle: For branch nodes, the selected edge handle
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no edge was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
self._process_skipped_edge(edge)
|
||||
|
||||
# Process selected edges
|
||||
for edge in selected_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
Args:
|
||||
edge: The edge to process
|
||||
|
||||
Returns:
|
||||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
Mark edge as skipped.
|
||||
|
||||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
|
||||
When a node is skipped, this ensures all downstream nodes
|
||||
that depend solely on it are also skipped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
Recursively propagate skip state from a skipped edge.
|
||||
|
||||
Rules:
|
||||
- If a node has any UNKNOWN incoming edges, stop processing
|
||||
- If all incoming edges are SKIPPED, skip the node and its edges
|
||||
- If any incoming edge is TAKEN, the node may still execute
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
return
|
||||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
if edge_states["all_skipped"]:
|
||||
self._propagate_skip_to_node(downstream_node_id)
|
||||
|
||||
def _propagate_skip_to_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node and all its outgoing edges as skipped.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
52
api/core/workflow/graph_engine/layers/README.md
Normal file
52
api/core/workflow/graph_engine/layers/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Layers
|
||||
|
||||
Pluggable middleware for engine extensions.
|
||||
|
||||
## Components
|
||||
|
||||
### Layer (base)
|
||||
|
||||
Abstract base class for layers.
|
||||
|
||||
- `initialize()` - Receive runtime context
|
||||
- `on_graph_start()` - Execution start hook
|
||||
- `on_event()` - Process all events
|
||||
- `on_graph_end()` - Execution end hook
|
||||
|
||||
### DebugLoggingLayer
|
||||
|
||||
Comprehensive execution logging.
|
||||
|
||||
- Configurable detail levels
|
||||
- Tracks execution statistics
|
||||
- Truncates long values
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
debug_layer = DebugLoggingLayer(
|
||||
level="INFO",
|
||||
include_outputs=True
|
||||
)
|
||||
|
||||
engine = GraphEngine(graph)
|
||||
engine.add_layer(debug_layer)
|
||||
engine.run()
|
||||
```
|
||||
|
||||
## Custom Layers
|
||||
|
||||
```python
|
||||
class MetricsLayer(Layer):
|
||||
def on_event(self, event):
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self.metrics[event.node_id] = event.elapsed_time
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
**DebugLoggingLayer Options:**
|
||||
|
||||
- `level` - Log level (INFO, DEBUG, ERROR)
|
||||
- `include_inputs/outputs` - Log data values
|
||||
- `max_value_length` - Truncate long values
|
||||
16
api/core/workflow/graph_engine/layers/__init__.py
Normal file
16
api/core/workflow/graph_engine/layers/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Layer system for GraphEngine extensibility.
|
||||
|
||||
This module provides the layer infrastructure for extending GraphEngine functionality
|
||||
with middleware-like components that can observe events and interact with execution.
|
||||
"""
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"GraphEngineLayer",
|
||||
]
|
||||
85
api/core/workflow/graph_engine/layers/base.py
Normal file
85
api/core/workflow/graph_engine/layers/base.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Base layer class for GraphEngine extensions.
|
||||
|
||||
This module provides the abstract base class for implementing layers that can
|
||||
intercept and respond to GraphEngine events.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
|
||||
Layers are middleware-like components that can:
|
||||
- Observe all events emitted by the GraphEngine
|
||||
- Access the graph runtime state
|
||||
- Send commands to control execution
|
||||
|
||||
Subclasses should override the constructor to accept configuration parameters,
|
||||
then implement the three lifecycle methods.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
|
||||
This is called after the engine has been initialized but before any nodes
|
||||
are executed. Layers can use this to set up resources or log start information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
This method receives all events generated during graph execution, including:
|
||||
- Graph lifecycle events (start, success, failure)
|
||||
- Node execution events (start, success, failure, retry)
|
||||
- Stream events for response nodes
|
||||
- Container events (iteration, loop)
|
||||
|
||||
Args:
|
||||
event: The event emitted by the engine
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
This is called after all nodes have been executed or when execution is
|
||||
aborted. Layers can use this to clean up resources or log final state.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
241
api/core/workflow/graph_engine/layers/debug_logging.py
Normal file
241
api/core/workflow/graph_engine/layers/debug_logging.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Debug logging layer for GraphEngine.
|
||||
|
||||
This module provides a layer that logs all events and state changes during
|
||||
graph execution for debugging purposes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(GraphEngineLayer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
|
||||
This layer logs all events with configurable detail levels, helping developers
|
||||
debug workflow execution and understand the flow of events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level: str = "INFO",
|
||||
include_inputs: bool = False,
|
||||
include_outputs: bool = True,
|
||||
include_process_data: bool = False,
|
||||
logger_name: str = "GraphEngine.Debug",
|
||||
max_value_length: int = 500,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the debug logging layer.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
include_inputs: Whether to log node input values
|
||||
include_outputs: Whether to log node output values
|
||||
include_process_data: Whether to log node process data
|
||||
logger_name: Name of the logger to use
|
||||
max_value_length: Maximum length of logged values (truncated if longer)
|
||||
"""
|
||||
super().__init__()
|
||||
self.level = level
|
||||
self.include_inputs = include_inputs
|
||||
self.include_outputs = include_outputs
|
||||
self.include_process_data = include_process_data
|
||||
self.max_value_length = max_value_length
|
||||
|
||||
# Set up logger
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
log_level = getattr(logging, level.upper(), logging.INFO)
|
||||
self.logger.setLevel(log_level)
|
||||
|
||||
# Track execution stats
|
||||
self.node_count = 0
|
||||
self.success_count = 0
|
||||
self.failure_count = 0
|
||||
self.retry_count = 0
|
||||
|
||||
def _truncate_value(self, value: Any) -> str:
|
||||
"""Truncate long values for logging."""
|
||||
str_value = str(value)
|
||||
if len(str_value) > self.max_value_length:
|
||||
return str_value[: self.max_value_length] + "... (truncated)"
|
||||
return str_value
|
||||
|
||||
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
|
||||
"""Format a dictionary or mapping for logging with truncation."""
|
||||
if not data:
|
||||
return "{}"
|
||||
|
||||
formatted_items: list[str] = []
|
||||
for key, value in data.items():
|
||||
formatted_value = self._truncate_value(value)
|
||||
formatted_items.append(f" {key}: {formatted_value}")
|
||||
|
||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Log graph execution start."""
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if self.graph_runtime_state:
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""Log individual events based on their type."""
|
||||
event_class = event.__class__.__name__
|
||||
|
||||
# Graph-level events
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.logger.debug("Graph run started event")
|
||||
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.logger.info("✅ Graph run succeeded")
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.logger.error("❌ Graph run failed: %s", event.error)
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.error(" Total exceptions: %s", event.exceptions_count)
|
||||
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
|
||||
if event.outputs:
|
||||
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
# Node-level events
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.node_count += 1
|
||||
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
|
||||
|
||||
if self.include_inputs and event.node_run_result.inputs:
|
||||
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
|
||||
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.success_count += 1
|
||||
self.logger.info("✅ Node succeeded: %s", event.node_id)
|
||||
|
||||
if self.include_outputs and event.node_run_result.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
|
||||
|
||||
if self.include_process_data and event.node_run_result.process_data:
|
||||
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.failure_count += 1
|
||||
self.logger.error("❌ Node failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
if event.node_run_result.error:
|
||||
self.logger.error(" Details: %s", event.node_run_result.error)
|
||||
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
|
||||
self.logger.warning(" Error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self.retry_count += 1
|
||||
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
|
||||
self.logger.warning(" Previous error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
# Log stream chunks at debug level to avoid spam
|
||||
final_indicator = " (FINAL)" if event.is_final else ""
|
||||
self.logger.debug(
|
||||
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
|
||||
)
|
||||
|
||||
# Iteration events
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self.logger.info("🔁 Iteration started: %s", event.node_id)
|
||||
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
|
||||
|
||||
elif isinstance(event, NodeRunIterationSucceededEvent):
|
||||
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, NodeRunIterationFailedEvent):
|
||||
self.logger.error("❌ Iteration failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
# Loop events
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self.logger.info("🔄 Loop started: %s", event.node_id)
|
||||
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
|
||||
|
||||
elif isinstance(event, NodeRunLoopSucceededEvent):
|
||||
self.logger.info("✅ Loop succeeded: %s", event.node_id)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, NodeRunLoopFailedEvent):
|
||||
self.logger.error("❌ Loop failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
else:
|
||||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if error:
|
||||
self.logger.error("🔴 GRAPH EXECUTION FAILED")
|
||||
self.logger.error(" Error: %s", error)
|
||||
else:
|
||||
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
|
||||
|
||||
# Log execution statistics
|
||||
self.logger.info("Execution Statistics:")
|
||||
self.logger.info(" Total nodes executed: %s", self.node_count)
|
||||
self.logger.info(" Successful nodes: %s", self.success_count)
|
||||
self.logger.info(" Failed nodes: %s", self.failure_count)
|
||||
self.logger.info(" Node retries: %s", self.retry_count)
|
||||
|
||||
# Log final state if available
|
||||
if self.graph_runtime_state and self.include_outputs:
|
||||
if self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
|
||||
self.logger.info("=" * 80)
|
||||
150
api/core/workflow/graph_engine/layers/execution_limits.py
Normal file
150
api/core/workflow/graph_engine/layers/execution_limits.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Execution limits layer for GraphEngine.
|
||||
|
||||
This layer monitors workflow execution to enforce limits on:
|
||||
- Maximum execution steps
|
||||
- Maximum execution time
|
||||
|
||||
When limits are exceeded, the layer automatically aborts execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
||||
Monitors:
|
||||
- Step count: Tracks number of node executions
|
||||
- Time limit: Monitors total execution time
|
||||
|
||||
Automatically aborts execution when limits are exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, max_steps: int, max_time: int) -> None:
|
||||
"""
|
||||
Initialize the execution limits layer.
|
||||
|
||||
Args:
|
||||
max_steps: Maximum number of execution steps allowed
|
||||
max_time: Maximum execution time in seconds allowed
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# State tracking
|
||||
self._execution_started = False
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
self.step_count = 0
|
||||
self._execution_started = True
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False
|
||||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
Monitors execution progress and enforces limits.
|
||||
"""
|
||||
if not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Track step count for node execution events
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self.step_count += 1
|
||||
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
|
||||
|
||||
# Check step limit when node execution completes
|
||||
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
|
||||
if self._reached_step_limitation():
|
||||
self._send_abort_command(LimitType.STEP_LIMIT)
|
||||
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
||||
if self.start_time:
|
||||
total_time = time.time() - self.start_time
|
||||
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
|
||||
|
||||
def _reached_step_limitation(self) -> bool:
|
||||
"""Check if step count limit has been exceeded."""
|
||||
return self.step_count > self.max_steps
|
||||
|
||||
def _reached_time_limitation(self) -> bool:
|
||||
"""Check if time limit has been exceeded."""
|
||||
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
|
||||
|
||||
def _send_abort_command(self, limit_type: LimitType) -> None:
|
||||
"""
|
||||
Send abort command due to limit violation.
|
||||
|
||||
Args:
|
||||
limit_type: Type of limit exceeded
|
||||
"""
|
||||
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Format detailed reason message
|
||||
if limit_type == LimitType.STEP_LIMIT:
|
||||
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
try:
|
||||
# Send abort command to the engine
|
||||
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
|
||||
self.command_channel.send_command(abort_command)
|
||||
|
||||
# Mark that abort has been sent to prevent duplicate commands
|
||||
self._abort_sent = True
|
||||
|
||||
self.logger.debug("Abort command sent to engine")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Failed to send abort command: %s")
|
||||
50
api/core/workflow/graph_engine/manager.py
Normal file
50
api/core/workflow/graph_engine/manager.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
"""
|
||||
Manager for sending control commands to GraphEngine instances.
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
Args:
|
||||
task_id: The task ID of the workflow to stop
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Create Redis channel for this task
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
# Create and send abort command
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
|
||||
try:
|
||||
channel.send_command(abort_command)
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
# The legacy stop flag mechanism will still work
|
||||
pass
|
||||
14
api/core/workflow/graph_engine/orchestration/__init__.py
Normal file
14
api/core/workflow/graph_engine/orchestration/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Orchestration subsystem for graph engine.
|
||||
|
||||
This package coordinates the overall execution flow between
|
||||
different subsystems.
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
__all__ = [
|
||||
"Dispatcher",
|
||||
"ExecutionCoordinator",
|
||||
]
|
||||
104
api/core/workflow/graph_engine/orchestration/dispatcher.py
Normal file
104
api/core/workflow/graph_engine/orchestration/dispatcher.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Main dispatcher for processing events from workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class Dispatcher:
|
||||
"""
|
||||
Main dispatcher that processes events from the event queue.
|
||||
|
||||
This runs in a separate thread and coordinates event processing
|
||||
with timeout and completion detection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
|
||||
Args:
|
||||
event_queue: Queue of events from workers
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting unhandled events
|
||||
execution_coordinator: Coordinator for execution flow
|
||||
event_emitter: Optional event manager to signal completion
|
||||
"""
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dispatcher thread."""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for commands
|
||||
self._execution_coordinator.check_commands()
|
||||
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Check if execution is complete
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
self._execution_coordinator.mark_failed(e)
|
||||
|
||||
finally:
|
||||
self._execution_coordinator.mark_complete()
|
||||
# Signal the event emitter that execution is complete
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventManager
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
"""
|
||||
Coordinates overall execution flow between subsystems.
|
||||
|
||||
This provides high-level coordination methods used by the
|
||||
dispatcher to manage execution state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
state_manager: GraphStateManager,
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the execution coordinator.
|
||||
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
state_manager: Unified state manager
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting events
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def check_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self._command_processor.process_commands()
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if execution is complete.
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
# Check if aborted or failed
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
"""
|
||||
Mark execution as failed.
|
||||
|
||||
Args:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self._graph_execution.fail(error)
|
||||
41
api/core/workflow/graph_engine/protocols/command_channel.py
Normal file
41
api/core/workflow/graph_engine/protocols/command_channel.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
CommandChannel protocol for GraphEngine command communication.
|
||||
|
||||
This protocol defines the interface for sending and receiving commands
|
||||
to/from a GraphEngine instance, supporting both local and distributed scenarios.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
class CommandChannel(Protocol):
|
||||
"""
|
||||
Protocol for bidirectional command communication with GraphEngine.
|
||||
|
||||
Since each GraphEngine instance processes only one workflow execution,
|
||||
this channel is dedicated to that single execution.
|
||||
"""
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch pending commands for this GraphEngine instance.
|
||||
|
||||
Called by GraphEngine to poll for commands that need to be processed.
|
||||
|
||||
Returns:
|
||||
List of pending commands (may be empty)
|
||||
"""
|
||||
...
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to be processed by this GraphEngine instance.
|
||||
|
||||
Called by external systems to send control commands to the running workflow.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
...
|
||||
12
api/core/workflow/graph_engine/ready_queue/__init__.py
Normal file
12
api/core/workflow/graph_engine/ready_queue/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Ready queue implementations for GraphEngine.
|
||||
|
||||
This package contains the protocol and implementations for managing
|
||||
the queue of nodes ready for execution.
|
||||
"""
|
||||
|
||||
from .factory import create_ready_queue_from_state
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]
|
||||
35
api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
35
api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Factory for creating ReadyQueue instances from serialized state.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueueState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protocol import ReadyQueue
|
||||
|
||||
|
||||
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
||||
"""
|
||||
Create a ReadyQueue instance from a serialized state.
|
||||
|
||||
Args:
|
||||
state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue
|
||||
|
||||
Returns:
|
||||
A ReadyQueue instance initialized with the given state
|
||||
|
||||
Raises:
|
||||
ValueError: If the queue type is unknown or version is unsupported
|
||||
"""
|
||||
if state.type == "InMemoryReadyQueue":
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
|
||||
queue = InMemoryReadyQueue()
|
||||
# Always pass as JSON string to loads()
|
||||
queue.loads(state.model_dump_json())
|
||||
return queue
|
||||
else:
|
||||
raise ValueError(f"Unknown ready queue type: {state.type}")
|
||||
140
api/core/workflow/graph_engine/ready_queue/in_memory.py
Normal file
140
api/core/workflow/graph_engine/ready_queue/in_memory.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
In-memory implementation of the ReadyQueue protocol.
|
||||
|
||||
This implementation wraps Python's standard queue.Queue and adds
|
||||
serialization capabilities for state storage.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import final
|
||||
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryReadyQueue(ReadyQueue):
|
||||
"""
|
||||
In-memory ready queue implementation with serialization support.
|
||||
|
||||
This implementation uses Python's queue.Queue internally and provides
|
||||
methods to serialize and restore the queue state.
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 0) -> None:
|
||||
"""
|
||||
Initialize the in-memory ready queue.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum size of the queue (0 for unlimited)
|
||||
"""
|
||||
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
self._queue.put(item)
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
if timeout is None:
|
||||
return self._queue.get(block=True)
|
||||
return self._queue.get(timeout=timeout)
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
self._queue.task_done()
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
return self._queue.empty()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
"""
|
||||
# Extract all items from the queue without removing them
|
||||
items: list[str] = []
|
||||
temp_items: list[str] = []
|
||||
|
||||
# Drain the queue temporarily to get all items
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
temp_items.append(item)
|
||||
items.append(item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Put items back in the same order
|
||||
for item in temp_items:
|
||||
self._queue.put(item)
|
||||
|
||||
state = ReadyQueueState(
|
||||
type="InMemoryReadyQueue",
|
||||
version="1.0",
|
||||
items=items,
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
state = ReadyQueueState.model_validate_json(data)
|
||||
|
||||
if state.type != "InMemoryReadyQueue":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported version: {state.version}")
|
||||
|
||||
# Clear the current queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Restore items
|
||||
for item in state.items:
|
||||
self._queue.put(item)
|
||||
104
api/core/workflow/graph_engine/ready_queue/protocol.py
Normal file
104
api/core/workflow/graph_engine/ready_queue/protocol.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
ReadyQueue protocol for GraphEngine node execution queue.
|
||||
|
||||
This protocol defines the interface for managing the queue of nodes ready
|
||||
for execution, supporting both in-memory and persistent storage scenarios.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ReadyQueueState(BaseModel):
|
||||
"""
|
||||
Pydantic model for serialized ready queue state.
|
||||
|
||||
This defines the structure of the data returned by dumps()
|
||||
and expected by loads() for ready queue serialization.
|
||||
"""
|
||||
|
||||
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
|
||||
version: str = Field(description="Serialization format version")
|
||||
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
|
||||
|
||||
|
||||
class ReadyQueue(Protocol):
|
||||
"""
|
||||
Protocol for managing nodes ready for execution in GraphEngine.
|
||||
|
||||
This protocol defines the interface that any ready queue implementation
|
||||
must provide, enabling both in-memory queues and persistent queues
|
||||
that can be serialized for state storage.
|
||||
"""
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
...
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
...
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
that can be persisted and later restored
|
||||
"""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
ResponseStreamCoordinator - Coordinates streaming output from response nodes
|
||||
|
||||
This component manages response streaming sessions and ensures ordered streaming
|
||||
of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
from .coordinator import ResponseStreamCoordinator
|
||||
|
||||
__all__ = ["ResponseStreamCoordinator"]
|
||||
@@ -0,0 +1,696 @@
|
||||
"""
|
||||
Main ResponseStreamCoordinator implementation.
|
||||
|
||||
This module contains the public ResponseStreamCoordinator class that manages
|
||||
response streaming sessions and ensures ordered streaming of responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Literal, TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type definitions
|
||||
NodeID: TypeAlias = str
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
class ResponseSessionState(BaseModel):
|
||||
"""Serializable representation of a response session."""
|
||||
|
||||
node_id: str
|
||||
index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class StreamBufferState(BaseModel):
|
||||
"""Serializable representation of buffered stream chunks."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamPositionState(BaseModel):
|
||||
"""Serializable representation for stream read positions."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
position: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorState(BaseModel):
|
||||
"""Serialized snapshot of ResponseStreamCoordinator."""
|
||||
|
||||
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
|
||||
version: str = Field(default="1.0")
|
||||
response_nodes: Sequence[str] = Field(default_factory=list)
|
||||
active_session: ResponseSessionState | None = None
|
||||
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
node_execution_ids: dict[str, str] = Field(default_factory=dict)
|
||||
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
|
||||
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
|
||||
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
|
||||
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@final
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
Manages response streaming sessions without relying on global state.
|
||||
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
|
||||
"""
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
Args:
|
||||
variable_pool: VariablePool instance for accessing node variables
|
||||
graph: Graph instance for looking up node information
|
||||
"""
|
||||
self._variable_pool = variable_pool
|
||||
self._graph = graph
|
||||
self._active_session: ResponseSession | None = None
|
||||
self._waiting_sessions: deque[ResponseSession] = deque()
|
||||
self._lock = RLock()
|
||||
|
||||
# Internal stream management (replacing OutputRegistry)
|
||||
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
|
||||
self._stream_positions: dict[tuple[str, ...], int] = {}
|
||||
self._closed_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Track response nodes
|
||||
self._response_nodes: set[NodeID] = set()
|
||||
|
||||
# Store paths for each response node
|
||||
self._paths_maps: dict[NodeID, list[Path]] = {}
|
||||
|
||||
# Track node execution IDs and types for proper event forwarding
|
||||
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
|
||||
|
||||
# Track response sessions to ensure only one per node
|
||||
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
|
||||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self._lock:
|
||||
if response_node_id in self._response_nodes:
|
||||
return
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
paths_map = self._build_paths_map(response_node_id)
|
||||
self._paths_maps[response_node_id] = paths_map
|
||||
|
||||
# Create and store response session for this node
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
session = ResponseSession.from_node(response_node)
|
||||
self._response_sessions[response_node_id] = session
|
||||
|
||||
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
|
||||
"""Track the execution ID for a node when it starts executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
execution_id: The execution ID from NodeRunStartedEvent
|
||||
"""
|
||||
with self._lock:
|
||||
self._node_execution_ids[node_id] = execution_id
|
||||
|
||||
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
|
||||
"""Get the execution ID for a node, creating one if it doesn't exist.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The execution ID for the node
|
||||
"""
|
||||
with self._lock:
|
||||
if node_id not in self._node_execution_ids:
|
||||
self._node_execution_ids[node_id] = str(uuid4())
|
||||
return self._node_execution_ids[node_id]
|
||||
|
||||
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
|
||||
"""
|
||||
Build a paths map for a response node by finding all paths from root node
|
||||
to the response node, recording branch edges along each path.
|
||||
|
||||
Args:
|
||||
response_node_id: ID of the response node to analyze
|
||||
|
||||
Returns:
|
||||
List of Path objects, where each path contains branch edge IDs
|
||||
"""
|
||||
# Get root node ID
|
||||
root_node_id = self._graph.root_node.id
|
||||
|
||||
# If root is the response node, return empty path
|
||||
if root_node_id == response_node_id:
|
||||
return [Path()]
|
||||
|
||||
# Extract variable selectors from the response node's template
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
response_session = ResponseSession.from_node(response_node)
|
||||
template = response_session.template
|
||||
|
||||
# Collect all variable selectors from the template
|
||||
variable_selectors: set[tuple[str, ...]] = set()
|
||||
for segment in template.segments:
|
||||
if isinstance(segment, VariableSegment):
|
||||
variable_selectors.add(tuple(segment.selector[:2]))
|
||||
|
||||
# Step 1: Find all complete paths from root to response node
|
||||
all_complete_paths: list[list[EdgeID]] = []
|
||||
|
||||
def find_paths(
|
||||
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
|
||||
) -> None:
|
||||
"""Recursively find all paths from current node to target node."""
|
||||
if current_node_id == target_node_id:
|
||||
# Found a complete path, store it
|
||||
all_complete_paths.append(current_path.copy())
|
||||
return
|
||||
|
||||
# Mark as visited to avoid cycles
|
||||
visited.add(current_node_id)
|
||||
|
||||
# Explore outgoing edges
|
||||
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
|
||||
for edge in outgoing_edges:
|
||||
edge_id = edge.id
|
||||
next_node_id = edge.head
|
||||
|
||||
# Skip if already visited in this path
|
||||
if next_node_id not in visited:
|
||||
# Add edge to path and recurse
|
||||
new_path = current_path + [edge_id]
|
||||
find_paths(next_node_id, target_node_id, new_path, visited.copy())
|
||||
|
||||
# Start searching from root node
|
||||
find_paths(root_node_id, response_node_id, [], set())
|
||||
|
||||
# Step 2: For each complete path, filter edges based on node blocking behavior
|
||||
filtered_paths: list[Path] = []
|
||||
for path in all_complete_paths:
|
||||
blocking_edges: list[str] = []
|
||||
for edge_id in path:
|
||||
edge = self._graph.edges[edge_id]
|
||||
source_node = self._graph.nodes[edge.tail]
|
||||
|
||||
# Check if node is a branch/container (original behavior)
|
||||
if source_node.execution_type in {
|
||||
NodeExecutionType.BRANCH,
|
||||
NodeExecutionType.CONTAINER,
|
||||
} or source_node.blocks_variable_output(variable_selectors):
|
||||
blocking_edges.append(edge_id)
|
||||
|
||||
# Keep the path even if it's empty
|
||||
filtered_paths.append(Path(edges=blocking_edges))
|
||||
|
||||
return filtered_paths
|
||||
|
||||
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
Handle when an edge is taken (selected by a branch node).
|
||||
|
||||
This method updates the paths for all response nodes by removing
|
||||
the taken edge. If any response node has an empty path after removal,
|
||||
it means the node is now deterministically reachable and should start.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge that was taken
|
||||
|
||||
Returns:
|
||||
List of events to emit from starting new sessions
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
with self._lock:
|
||||
# Check each response node in order
|
||||
for response_node_id in self._response_nodes:
|
||||
if response_node_id not in self._paths_maps:
|
||||
continue
|
||||
|
||||
paths = self._paths_maps[response_node_id]
|
||||
has_reachable_path = False
|
||||
|
||||
# Update each path by removing the taken edge
|
||||
for path in paths:
|
||||
# Remove the taken edge from this path
|
||||
path.remove_edge(edge_id)
|
||||
|
||||
# Check if this path is now empty (node is reachable)
|
||||
if path.is_empty():
|
||||
has_reachable_path = True
|
||||
|
||||
# If node is now reachable (has empty path), start/queue session
|
||||
if has_reachable_path:
|
||||
# Pass the node_id to the activation method
|
||||
# The method will handle checking and removing from map
|
||||
events.extend(self._active_or_queue_session(response_node_id))
|
||||
return events
|
||||
|
||||
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
Start a session immediately if no active session, otherwise queue it.
|
||||
Only activates sessions that exist in the _response_sessions map.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the response node to activate
|
||||
|
||||
Returns:
|
||||
List of events from flush attempt if session started immediately
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Get the session from our map (only activate if it exists)
|
||||
session = self._response_sessions.get(node_id)
|
||||
if not session:
|
||||
return events
|
||||
|
||||
# Remove from map to ensure it won't be activated again
|
||||
del self._response_sessions[node_id]
|
||||
|
||||
if self._active_session is None:
|
||||
self._active_session = session
|
||||
|
||||
# Try to flush immediately
|
||||
events.extend(self.try_flush())
|
||||
else:
|
||||
# Queue the session if another is active
|
||||
self._waiting_sessions.append(session)
|
||||
|
||||
return events
|
||||
|
||||
def intercept_event(
|
||||
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
|
||||
) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
with self._lock:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._append_stream_chunk(event.selector, event)
|
||||
if event.is_final:
|
||||
self._close_stream(event.selector)
|
||||
return self.try_flush()
|
||||
else:
|
||||
# Skip cause we share the same variable pool.
|
||||
#
|
||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
# self._variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
return self.try_flush()
|
||||
|
||||
def _create_stream_chunk_event(
|
||||
self,
|
||||
node_id: str,
|
||||
execution_id: str,
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
For selectors with special prefixes (sys, env, conversation), we use the
|
||||
active response node's information since these are not actual node IDs.
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
# Use the active response node for special selectors
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
node = self._graph.nodes[node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
"""Process a variable segment. Returns (events, is_complete).
|
||||
|
||||
Handles both regular node selectors and special system selectors (sys, env, conversation).
|
||||
For special selectors, we attribute the output to the active response node.
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
is_complete = False
|
||||
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Stream all available chunks
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=event.selector, # Keep original selector
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
# Regular node selector - use event as is
|
||||
events.append(event)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
|
||||
return events, is_complete
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self._active_session is not None
|
||||
current_response_node = self._graph.nodes[self._active_session.node_id]
|
||||
|
||||
# Use get_or_create_execution_id to ensure we have a consistent ID
|
||||
execution_id = self._get_or_create_execution_id(current_response_node.id)
|
||||
|
||||
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
event = self._create_stream_chunk_event(
|
||||
node_id=current_response_node.id,
|
||||
execution_id=execution_id,
|
||||
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
|
||||
chunk=segment.text,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
return [event]
|
||||
|
||||
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
|
||||
with self._lock:
|
||||
if not self._active_session:
|
||||
return []
|
||||
|
||||
template = self._active_session.template
|
||||
response_node_id = self._active_session.node_id
|
||||
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Process segments sequentially from current index
|
||||
while self._active_session.index < len(template.segments):
|
||||
segment = template.segments[self._active_session.index]
|
||||
|
||||
if isinstance(segment, VariableSegment):
|
||||
# Check if the source node for this variable is skipped
|
||||
# Only check for actual nodes, not special selectors (sys, env, conversation)
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
if source_selector_prefix in self._graph.nodes:
|
||||
source_node = self._graph.nodes[source_selector_prefix]
|
||||
|
||||
if source_node.state == NodeState.SKIPPED:
|
||||
# Skip this variable segment if the source node is skipped
|
||||
self._active_session.index += 1
|
||||
continue
|
||||
|
||||
segment_events, is_complete = self._process_variable_segment(segment)
|
||||
events.extend(segment_events)
|
||||
|
||||
# Only advance index if this variable segment is complete
|
||||
if is_complete:
|
||||
self._active_session.index += 1
|
||||
else:
|
||||
# Wait for more data
|
||||
break
|
||||
|
||||
else:
|
||||
segment_events = self._process_text_segment(segment)
|
||||
events.extend(segment_events)
|
||||
self._active_session.index += 1
|
||||
|
||||
if self._active_session.is_complete():
|
||||
# End current session and get events from starting next session
|
||||
next_session_events = self.end_session(response_node_id)
|
||||
events.extend(next_session_events)
|
||||
|
||||
return events
|
||||
|
||||
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
End the active session for a response node.
|
||||
Automatically starts the next waiting session if available.
|
||||
|
||||
Args:
|
||||
node_id: ID of the response node ending its session
|
||||
|
||||
Returns:
|
||||
List of events from starting the next session
|
||||
"""
|
||||
with self._lock:
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
if self._active_session and self._active_session.node_id == node_id:
|
||||
self._active_session = None
|
||||
|
||||
# Try to start next waiting session
|
||||
if self._waiting_sessions:
|
||||
next_session = self._waiting_sessions.popleft()
|
||||
self._active_session = next_session
|
||||
|
||||
# Immediately try to flush any available segments
|
||||
events = self.try_flush()
|
||||
|
||||
return events
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
event: The NodeRunStreamChunkEvent to append
|
||||
|
||||
Raises:
|
||||
ValueError: If the stream is already closed
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key in self._closed_streams:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
self._stream_buffers[key] = []
|
||||
self._stream_positions[key] = 0
|
||||
|
||||
self._stream_buffers[key].append(event)
|
||||
|
||||
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
|
||||
"""
|
||||
Pop the next unread stream chunk from the buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
The next event, or None if no unread events available
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return None
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
buffer = self._stream_buffers[key]
|
||||
|
||||
if position >= len(buffer):
|
||||
return None
|
||||
|
||||
event = buffer[position]
|
||||
self._stream_positions[key] = position + 1
|
||||
return event
|
||||
|
||||
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if the stream has unread events.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if there are unread events, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return False
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
return position < len(self._stream_buffers[key])
|
||||
|
||||
def _close_stream(self, selector: Sequence[str]) -> None:
|
||||
"""
|
||||
Mark a stream as closed (no more chunks can be appended).
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
"""
|
||||
key = tuple(selector)
|
||||
self._closed_streams.add(key)
|
||||
|
||||
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if a stream is closed.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if the stream is closed, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
return key in self._closed_streams
|
||||
|
||||
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
|
||||
"""Convert an in-memory session into its serializable form."""
|
||||
|
||||
if session is None:
|
||||
return None
|
||||
return ResponseSessionState(node_id=session.node_id, index=session.index)
|
||||
|
||||
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
|
||||
"""Rebuild a response session from serialized data."""
|
||||
|
||||
node = self._graph.nodes.get(session_state.node_id)
|
||||
if node is None:
|
||||
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
|
||||
|
||||
session = ResponseSession.from_node(node)
|
||||
session.index = session_state.index
|
||||
return session
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize coordinator state to JSON."""
|
||||
|
||||
with self._lock:
|
||||
state = ResponseStreamCoordinatorState(
|
||||
response_nodes=sorted(self._response_nodes),
|
||||
active_session=self._serialize_session(self._active_session),
|
||||
waiting_sessions=[
|
||||
session_state
|
||||
for session in list(self._waiting_sessions)
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
pending_sessions=[
|
||||
session_state
|
||||
for _, session in sorted(self._response_sessions.items())
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
|
||||
paths_map={
|
||||
node_id: [path.edges.copy() for path in paths]
|
||||
for node_id, paths in sorted(self._paths_maps.items())
|
||||
},
|
||||
stream_buffers=[
|
||||
StreamBufferState(
|
||||
selector=selector,
|
||||
events=[event.model_copy(deep=True) for event in events],
|
||||
)
|
||||
for selector, events in sorted(self._stream_buffers.items())
|
||||
],
|
||||
stream_positions=[
|
||||
StreamPositionState(selector=selector, position=position)
|
||||
for selector, position in sorted(self._stream_positions.items())
|
||||
],
|
||||
closed_streams=sorted(self._closed_streams),
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore coordinator state from JSON."""
|
||||
|
||||
state = ResponseStreamCoordinatorState.model_validate_json(data)
|
||||
|
||||
if state.type != "ResponseStreamCoordinator":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
with self._lock:
|
||||
self._response_nodes = set(state.response_nodes)
|
||||
self._paths_maps = {
|
||||
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
|
||||
for node_id, paths in state.paths_map.items()
|
||||
}
|
||||
self._node_execution_ids = dict(state.node_execution_ids)
|
||||
|
||||
self._stream_buffers = {
|
||||
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
|
||||
for buffer in state.stream_buffers
|
||||
}
|
||||
self._stream_positions = {
|
||||
tuple(position.selector): position.position for position in state.stream_positions
|
||||
}
|
||||
for selector in self._stream_buffers:
|
||||
self._stream_positions.setdefault(selector, 0)
|
||||
|
||||
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
|
||||
|
||||
self._waiting_sessions = deque(
|
||||
self._session_from_state(session_state) for session_state in state.waiting_sessions
|
||||
)
|
||||
self._response_sessions = {
|
||||
session_state.node_id: self._session_from_state(session_state)
|
||||
for session_state in state.pending_sessions
|
||||
}
|
||||
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
|
||||
35
api/core/workflow/graph_engine/response_coordinator/path.py
Normal file
35
api/core/workflow/graph_engine/response_coordinator/path.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Internal path representation for response coordinator.
|
||||
|
||||
This module contains the private Path class used internally by ResponseStreamCoordinator
|
||||
to track execution paths to response nodes.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
"""
|
||||
Represents a path of branch edges that must be taken to reach a response node.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
edges: list[EdgeID] = field(default_factory=list[EdgeID])
|
||||
|
||||
def contains_edge(self, edge_id: EdgeID) -> bool:
|
||||
"""Check if this path contains the given edge."""
|
||||
return edge_id in self.edges
|
||||
|
||||
def remove_edge(self, edge_id: EdgeID) -> None:
|
||||
"""Remove the given edge from this path in place."""
|
||||
if self.contains_edge(edge_id):
|
||||
self.edges.remove(edge_id)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the path has no edges (node is reachable)."""
|
||||
return len(self.edges) == 0
|
||||
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Internal response session management for response coordinator.
|
||||
|
||||
This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseSession:
|
||||
"""
|
||||
Represents an active response streaming session.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
template: Template # Template object from the response node
|
||||
index: int = 0 # Current position in the template segments
|
||||
|
||||
@classmethod
|
||||
def from_node(cls, node: Node) -> "ResponseSession":
|
||||
"""
|
||||
Create a ResponseSession from an AnswerNode or EndNode.
|
||||
|
||||
Args:
|
||||
node: Must be either an AnswerNode or EndNode instance
|
||||
|
||||
Returns:
|
||||
ResponseSession configured with the node's streaming template
|
||||
|
||||
Raises:
|
||||
TypeError: If node is not an AnswerNode or EndNode
|
||||
"""
|
||||
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
|
||||
raise TypeError
|
||||
return cls(
|
||||
node_id=node.id,
|
||||
template=node.get_streaming_template(),
|
||||
)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all segments in the template have been processed."""
|
||||
return self.index >= len(self.template.segments)
|
||||
142
api/core/workflow/graph_engine/worker.py
Normal file
142
api/core/workflow/graph_engine/worker.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Worker - Thread implementation for queue-based node execution
|
||||
|
||||
Workers pull node IDs from the ready_queue, execute nodes, and push events
|
||||
to the event_queue for the dispatcher to process.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
"""
|
||||
Worker thread that executes nodes from the ready queue.
|
||||
|
||||
Workers continuously pull node IDs from the ready_queue, execute the
|
||||
corresponding nodes, and push the resulting events to the event_queue
|
||||
for the dispatcher to process.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
|
||||
Args:
|
||||
ready_queue: Ready queue containing node IDs ready for execution
|
||||
event_queue: Queue for pushing execution events
|
||||
graph: Graph containing nodes to execute
|
||||
worker_id: Unique identifier for this worker
|
||||
flask_app: Optional Flask application for context preservation
|
||||
context_vars: Optional context variables to preserve in worker thread
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
"""Check if the worker is currently idle."""
|
||||
# Worker is idle if it hasn't processed a task recently (within 0.2 seconds)
|
||||
return (time.time() - self._last_task_time) > 0.2
|
||||
|
||||
@property
|
||||
def idle_duration(self) -> float:
|
||||
"""Get the duration in seconds since the worker last processed a task."""
|
||||
return time.time() - self._last_task_time
|
||||
|
||||
@property
|
||||
def worker_id(self) -> int:
|
||||
"""Get the worker's ID."""
|
||||
return self._worker_id
|
||||
|
||||
@override
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Main worker loop.
|
||||
|
||||
Continuously pulls node IDs from ready_queue, executes them,
|
||||
and pushes events to event_queue until stopped.
|
||||
"""
|
||||
while not self._stop_event.is_set():
|
||||
# Try to get a node ID from the ready queue (with timeout)
|
||||
try:
|
||||
node_id = self._ready_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
node_id="unknown",
|
||||
node_type=NodeType.CODE,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
)
|
||||
self._event_queue.put(error_event)
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
Execute a single node and handle its events.
|
||||
|
||||
Args:
|
||||
node: The node instance to execute
|
||||
"""
|
||||
# Execute the node with preserved context if Flask app is provided
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
# Execute the node
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self._event_queue.put(event)
|
||||
else:
|
||||
# Execute without context preservation
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self._event_queue.put(event)
|
||||
12
api/core/workflow/graph_engine/worker_management/__init__.py
Normal file
12
api/core/workflow/graph_engine/worker_management/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Worker management subsystem for graph engine.
|
||||
|
||||
This package manages the worker pool, including creation,
|
||||
scaling, and activity tracking.
|
||||
"""
|
||||
|
||||
from .worker_pool import WorkerPool
|
||||
|
||||
__all__ = [
|
||||
"WorkerPool",
|
||||
]
|
||||
291
api/core/workflow/graph_engine/worker_management/worker_pool.py
Normal file
291
api/core/workflow/graph_engine/worker_management/worker_pool.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Simple worker pool that consolidates functionality.
|
||||
|
||||
This is a simpler implementation that merges WorkerPool, ActivityTracker,
|
||||
DynamicScaler, and WorkerFactory into a single class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..ready_queue import ReadyQueue
|
||||
from ..worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextvars import Context
|
||||
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
"""
|
||||
Simple worker pool with integrated management.
|
||||
|
||||
This class consolidates all worker management functionality into
|
||||
a single, simpler implementation without excessive abstraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the simple worker pool.
|
||||
|
||||
Args:
|
||||
ready_queue: Ready queue for nodes ready for execution
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
flask_app: Optional Flask app for context preservation
|
||||
context_vars: Optional context variables
|
||||
min_workers: Minimum number of workers
|
||||
max_workers: Maximum number of workers
|
||||
scale_up_threshold: Queue depth to trigger scale up
|
||||
scale_down_idle_time: Seconds before scaling down idle workers
|
||||
"""
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
|
||||
# Scaling parameters with defaults
|
||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
||||
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
|
||||
# Worker management
|
||||
self._workers: list[Worker] = []
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
def start(self, initial_count: int | None = None) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
|
||||
Args:
|
||||
initial_count: Number of workers to start with (auto-calculated if None)
|
||||
"""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Calculate initial worker count
|
||||
if initial_count is None:
|
||||
node_count = len(self._graph.nodes)
|
||||
if node_count < 10:
|
||||
initial_count = self._min_workers
|
||||
elif node_count < 50:
|
||||
initial_count = min(self._min_workers + 1, self._max_workers)
|
||||
else:
|
||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
||||
|
||||
logger.debug(
|
||||
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
||||
initial_count,
|
||||
node_count,
|
||||
self._min_workers,
|
||||
self._max_workers,
|
||||
)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
self._create_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop all workers in the pool."""
|
||||
with self._lock:
|
||||
self._running = False
|
||||
worker_count = len(self._workers)
|
||||
|
||||
if worker_count > 0:
|
||||
logger.debug("Stopping worker pool: %d workers", worker_count)
|
||||
|
||||
# Stop all workers
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
def _create_worker(self) -> None:
|
||||
"""Create and start a new worker."""
|
||||
worker_id = self._worker_counter
|
||||
self._worker_counter += 1
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
self._workers.append(worker)
|
||||
|
||||
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
|
||||
"""Remove a specific worker from the pool."""
|
||||
# Stop the worker
|
||||
worker.stop()
|
||||
|
||||
# Wait for it to finish
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
# Remove from list
|
||||
if worker in self._workers:
|
||||
self._workers.remove(worker)
|
||||
|
||||
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
|
||||
"""
|
||||
Try to scale up workers if needed.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
|
||||
Returns:
|
||||
True if scaled up, False otherwise
|
||||
"""
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
old_count = current_count
|
||||
self._create_worker()
|
||||
|
||||
logger.debug(
|
||||
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
queue_depth,
|
||||
self._scale_up_threshold,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
|
||||
"""
|
||||
Try to scale down workers if we have excess capacity.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
active_count: Number of active workers
|
||||
idle_count: Number of idle workers
|
||||
|
||||
Returns:
|
||||
True if scaled down, False otherwise
|
||||
"""
|
||||
# Skip if we're at minimum or have no idle workers
|
||||
if current_count <= self._min_workers or idle_count == 0:
|
||||
return False
|
||||
|
||||
# Check if we have excess capacity
|
||||
has_excess_capacity = (
|
||||
queue_depth <= active_count # Active workers can handle current queue
|
||||
or idle_count > active_count # More idle than active workers
|
||||
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
|
||||
)
|
||||
|
||||
if not has_excess_capacity:
|
||||
return False
|
||||
|
||||
# Find and remove idle workers that have been idle long enough
|
||||
workers_to_remove: list[tuple[Worker, int]] = []
|
||||
|
||||
for worker in self._workers:
|
||||
# Check if worker is idle and has exceeded idle time threshold
|
||||
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
|
||||
# Don't remove if it would leave us unable to handle the queue
|
||||
remaining_workers = current_count - len(workers_to_remove) - 1
|
||||
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
|
||||
workers_to_remove.append((worker, worker.worker_id))
|
||||
# Only remove one worker per check to avoid aggressive scaling
|
||||
break
|
||||
|
||||
# Remove idle workers if any found
|
||||
if workers_to_remove:
|
||||
old_count = current_count
|
||||
for worker, worker_id in workers_to_remove:
|
||||
self._remove_worker(worker, worker_id)
|
||||
|
||||
logger.debug(
|
||||
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
|
||||
"queue_depth=%d, active=%d, idle=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
len(workers_to_remove),
|
||||
self._scale_down_idle_time,
|
||||
queue_depth,
|
||||
active_count,
|
||||
idle_count - len(workers_to_remove),
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Count active vs idle workers by querying their state directly
|
||||
idle_count = sum(1 for worker in self._workers if worker.is_idle)
|
||||
active_count = current_count - idle_count
|
||||
|
||||
# Try to scale up if queue is backing up
|
||||
self._try_scale_up(queue_depth, current_count)
|
||||
|
||||
# Try to scale down if we have excess capacity
|
||||
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
with self._lock:
|
||||
return len(self._workers)
|
||||
|
||||
def get_status(self) -> dict[str, int]:
|
||||
"""
|
||||
Get pool status information.
|
||||
|
||||
Returns:
|
||||
Dictionary with status information
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_workers": len(self._workers),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"min_workers": self._min_workers,
|
||||
"max_workers": self._max_workers,
|
||||
}
|
||||
Reference in New Issue
Block a user