Mirror of https://github.com/langgenius/dify.git, synced 2026-01-07 23:04:12 +00:00.

refactor(graph_engine): inline output_registry into response_coordinator
Signed-off-by: -LAN- <laipz8200@outlook.com>

@@ -37,7 +37,6 @@ type = layers
 layers =
     graph_engine
     response_coordinator
-    output_registry
 containers =
     core.workflow.graph_engine

@@ -35,7 +35,6 @@ from .event_management import EventCollector, EventEmitter, EventHandlerRegistry
 from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
 from .layers.base import Layer
 from .orchestration import Dispatcher, ExecutionCoordinator
-from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
 from .state_management import UnifiedStateManager

@@ -122,8 +121,9 @@ class GraphEngine:
         self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)

         # Response coordination
-        self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
-        self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph)
+        self.response_coordinator = ResponseStreamCoordinator(
+            variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
+        )

         # Event management
         self.event_collector = EventCollector()
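For orientation, the hunk above is the consumer-side change: `GraphEngine` stops building an `OutputRegistry` around the variable pool and hands the pool straight to the coordinator. A minimal runnable sketch of the new wiring, with stub classes standing in for the real dify types (only the constructor shape is taken from the diff):

```python
# Minimal sketch of the wiring change; VariablePool and Graph are stubs, and
# this coordinator only mirrors the constructor signature from the diff.
class VariablePool: ...
class Graph: ...

class ResponseStreamCoordinator:
    def __init__(self, variable_pool: VariablePool, graph: Graph) -> None:
        self.variable_pool = variable_pool  # the pool is used directly now
        self.graph = graph

pool, graph = VariablePool(), Graph()
# Old wiring: ResponseStreamCoordinator(registry=OutputRegistry(pool), graph=graph)
coordinator = ResponseStreamCoordinator(variable_pool=pool, graph=graph)
```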

@@ -1,10 +0,0 @@
-"""
-OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
-
-This component provides thread-safe storage and retrieval of node outputs,
-supporting both scalar values and streaming chunks with proper state management.
-"""
-
-from .registry import OutputRegistry
-
-__all__ = ["OutputRegistry"]

@@ -1,148 +0,0 @@
-"""
-Main OutputRegistry implementation.
-
-This module contains the public OutputRegistry class that provides
-thread-safe storage for node outputs.
-"""
-
-from collections.abc import Sequence
-from threading import RLock
-from typing import TYPE_CHECKING, Any, Union, final
-
-from core.variables import Segment
-from core.workflow.entities.variable_pool import VariablePool
-
-from .stream import Stream
-
-if TYPE_CHECKING:
-    from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-@final
-class OutputRegistry:
-    """
-    Thread-safe registry for storing and retrieving node outputs.
-
-    Supports both scalar values and streaming chunks with proper state management.
-    All operations are thread-safe using internal locking.
-    """
-
-    def __init__(self, variable_pool: VariablePool) -> None:
-        """Initialize empty registry with thread-safe storage."""
-        self._lock = RLock()
-        self._scalars = variable_pool
-        self._streams: dict[tuple[str, ...], Stream] = {}
-
-    def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
-        """Convert selector list to tuple key for internal storage."""
-        return tuple(selector)
-
-    def set_scalar(
-        self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
-    ) -> None:
-        """
-        Set a scalar value for the given selector.
-
-        Args:
-            selector: List of strings identifying the output location
-            value: The scalar value to store
-        """
-        with self._lock:
-            self._scalars.add(selector, value)
-
-    def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
-        """
-        Get a scalar value for the given selector.
-
-        Args:
-            selector: List of strings identifying the output location
-
-        Returns:
-            The stored Variable object, or None if not found
-        """
-        with self._lock:
-            return self._scalars.get(selector)
-
-    def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
-        """
-        Append a NodeRunStreamChunkEvent to the stream for the given selector.
-
-        Args:
-            selector: List of strings identifying the stream location
-            event: The NodeRunStreamChunkEvent to append
-
-        Raises:
-            ValueError: If the stream is already closed
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                self._streams[key] = Stream()
-
-            try:
-                self._streams[key].append(event)
-            except ValueError:
-                raise ValueError(f"Stream {'.'.join(selector)} is already closed")
-
-    def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
-        """
-        Pop the next unread NodeRunStreamChunkEvent from the stream.
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            The next event, or None if no unread events available
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return None
-
-            return self._streams[key].pop_next()
-
-    def has_unread(self, selector: Sequence[str]) -> bool:
-        """
-        Check if the stream has unread events.
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            True if there are unread events, False otherwise
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return False
-
-            return self._streams[key].has_unread()
-
-    def close_stream(self, selector: Sequence[str]) -> None:
-        """
-        Mark a stream as closed (no more chunks can be appended).
-
-        Args:
-            selector: List of strings identifying the stream location
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                self._streams[key] = Stream()
-            self._streams[key].close()
-
-    def stream_closed(self, selector: Sequence[str]) -> bool:
-        """
-        Check if a stream is closed.
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            True if the stream is closed, False otherwise
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return False
-            return self._streams[key].is_closed
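For context on the deleted class: scalars were never stored in the registry itself — `set_scalar`/`get_scalar` delegated to the shared `VariablePool` — while streams lived in the registry's private `_streams` dict keyed by `tuple(selector)`. A self-contained sketch of that split, with a hypothetical `MiniPool` standing in for `VariablePool`:

```python
# Hypothetical MiniPool stands in for VariablePool; the registry's scalar
# methods were thin locked wrappers around exactly these two calls.
class MiniPool:
    def __init__(self) -> None:
        self._data: dict[tuple[str, ...], object] = {}

    def add(self, selector: list[str], value: object) -> None:
        self._data[tuple(selector)] = value

    def get(self, selector: list[str]) -> object | None:
        return self._data.get(tuple(selector))

pool = MiniPool()
pool.add(["node1", "output"], "hello")           # what set_scalar did
assert pool.get(["node1", "output"]) == "hello"  # what get_scalar did

# Streams, by contrast, never touched the pool: they lived in the private
# _streams dict. That is why inlining the registry into the coordinator
# only meant moving the stream bookkeeping, not the scalar storage.
```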

@@ -1,70 +0,0 @@
-"""
-Internal stream implementation for OutputRegistry.
-
-This module contains the private Stream class used internally by OutputRegistry
-to manage streaming data chunks.
-"""
-
-from typing import TYPE_CHECKING, final
-
-if TYPE_CHECKING:
-    from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-@final
-class Stream:
-    """
-    A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
-
-    This class encapsulates stream-specific data and operations,
-    including event storage, read position tracking, and closed state.
-
-    Note: This is an internal class not exposed in the public API.
-    """
-
-    def __init__(self) -> None:
-        """Initialize an empty stream."""
-        self.events: list[NodeRunStreamChunkEvent] = []
-        self.read_position: int = 0
-        self.is_closed: bool = False
-
-    def append(self, event: "NodeRunStreamChunkEvent") -> None:
-        """
-        Append a NodeRunStreamChunkEvent to the stream.
-
-        Args:
-            event: The NodeRunStreamChunkEvent to append
-
-        Raises:
-            ValueError: If the stream is already closed
-        """
-        if self.is_closed:
-            raise ValueError("Cannot append to a closed stream")
-        self.events.append(event)
-
-    def pop_next(self) -> "NodeRunStreamChunkEvent | None":
-        """
-        Pop the next unread NodeRunStreamChunkEvent from the stream.
-
-        Returns:
-            The next event, or None if no unread events available
-        """
-        if self.read_position >= len(self.events):
-            return None
-
-        event = self.events[self.read_position]
-        self.read_position += 1
-        return event
-
-    def has_unread(self) -> bool:
-        """
-        Check if the stream has unread events.
-
-        Returns:
-            True if there are unread events, False otherwise
-        """
-        return self.read_position < len(self.events)
-
-    def close(self) -> None:
-        """Mark the stream as closed (no more chunks can be appended)."""
-        self.is_closed = True
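The read-side contract of this deleted class is worth spelling out: `pop_next()` never discards events, it only advances `read_position`, and `close()` merely blocks further appends. A self-contained sketch of those semantics, with plain strings standing in for `NodeRunStreamChunkEvent` objects:

```python
# Self-contained sketch of the cursor semantics, using plain strings in
# place of NodeRunStreamChunkEvent objects.
events: list[str] = ["chunk1", "chunk2"]
read_position = 0
is_closed = False

# pop_next(): return the event at the cursor and advance; the buffer is
# never mutated, so already-read chunks stay in memory.
event = events[read_position]
read_position += 1
assert event == "chunk1"

# has_unread(): the cursor has not yet reached the end of the buffer.
assert read_position < len(events)

# close(): only flips a flag that makes append() raise; buffered events
# remain readable after closing.
is_closed = True
```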

@@ -12,12 +12,12 @@ from threading import RLock
 from typing import TypeAlias, final
 from uuid import uuid4

+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import NodeExecutionType, NodeState
 from core.workflow.graph import Graph
 from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
 from core.workflow.nodes.base.template import TextSegment, VariableSegment

-from ..output_registry import OutputRegistry
 from .path import Path
 from .session import ResponseSession

@@ -36,20 +36,25 @@ class ResponseStreamCoordinator:
     Ensures ordered streaming of responses based on upstream node outputs and constants.
     """

-    def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
+    def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
         """
-        Initialize coordinator with output registry.
+        Initialize coordinator with variable pool.

         Args:
-            registry: OutputRegistry instance for accessing node outputs
+            variable_pool: VariablePool instance for accessing node variables
             graph: Graph instance for looking up node information
         """
-        self.registry = registry
+        self.variable_pool = variable_pool
         self.graph = graph
         self.active_session: ResponseSession | None = None
         self.waiting_sessions: deque[ResponseSession] = deque()
         self.lock = RLock()

+        # Internal stream management (replacing OutputRegistry)
+        self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
+        self._stream_positions: dict[tuple[str, ...], int] = {}
+        self._closed_streams: set[tuple[str, ...]] = set()
+
         # Track response nodes
         self._response_nodes: set[NodeID] = set()

@@ -256,15 +261,15 @@ class ResponseStreamCoordinator:
     ) -> Sequence[NodeRunStreamChunkEvent]:
         with self.lock:
             if isinstance(event, NodeRunStreamChunkEvent):
-                self.registry.append_chunk(event.selector, event)
+                self._append_stream_chunk(event.selector, event)
                 if event.is_final:
-                    self.registry.close_stream(event.selector)
+                    self._close_stream(event.selector)
                 return self.try_flush()
             else:
                 # Skip cause we share the same variable pool.
                 #
                 # for variable_name, variable_value in event.node_run_result.outputs.items():
-                #     self.registry.set_scalar((event.node_id, variable_name), variable_value)
+                #     self.variable_pool.add((event.node_id, variable_name), variable_value)
                 return self.try_flush()
         return []

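The hunk above preserves the event flow while swapping the storage backend: every chunk event is buffered under its selector, a chunk with `is_final=True` also closes that selector's stream, and both branches end in `try_flush()`. A self-contained sketch of that write path, with strings as chunks and names that are illustrative rather than the real API:

```python
# Illustrative write path, not the real API: strings stand in for chunks.
buffers: dict[tuple[str, ...], list[str]] = {}
closed: set[tuple[str, ...]] = set()

def on_chunk(selector: tuple[str, ...], chunk: str, is_final: bool) -> None:
    if selector in closed:
        raise ValueError(f"Stream {'.'.join(selector)} is already closed")
    buffers.setdefault(selector, []).append(chunk)  # buffer under selector
    if is_final:
        closed.add(selector)  # the final chunk also closes the stream

on_chunk(("node1", "text"), "Hel", is_final=False)
on_chunk(("node1", "text"), "lo", is_final=True)
assert buffers[("node1", "text")] == ["Hel", "lo"]
assert ("node1", "text") in closed
```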
@@ -327,8 +332,8 @@ class ResponseStreamCoordinator:
            execution_id = self._get_or_create_execution_id(output_node_id)

            # Stream all available chunks
-            while self.registry.has_unread(segment.selector):
-                if event := self.registry.pop_chunk(segment.selector):
+            while self._has_unread_stream(segment.selector):
+                if event := self._pop_stream_chunk(segment.selector):
                     # For special selectors, we need to update the event to use
                     # the active response node's information
                     if self.active_session and source_selector_prefix not in self.graph.nodes:

@@ -349,12 +354,12 @@ class ResponseStreamCoordinator:
                    events.append(event)

                    # Check if this is the last chunk by looking ahead
-                    stream_closed = self.registry.stream_closed(segment.selector)
+                    stream_closed = self._is_stream_closed(segment.selector)
                    # Check if stream is closed to determine if segment is complete
                    if stream_closed:
                        is_complete = True

-        elif value := self.registry.get_scalar(segment.selector):
+        elif value := self.variable_pool.get(segment.selector):
            # Process scalar value
            is_last_segment = bool(
                self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1

@@ -464,3 +469,93 @@ class ResponseStreamCoordinator:
             events = self.try_flush()

         return events
+
+    # ============= Internal Stream Management Methods =============
+
+    def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
+        """
+        Append a stream chunk to the internal buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+            event: The NodeRunStreamChunkEvent to append
+
+        Raises:
+            ValueError: If the stream is already closed
+        """
+        key = tuple(selector)
+
+        if key in self._closed_streams:
+            raise ValueError(f"Stream {'.'.join(selector)} is already closed")
+
+        if key not in self._stream_buffers:
+            self._stream_buffers[key] = []
+            self._stream_positions[key] = 0
+
+        self._stream_buffers[key].append(event)
+
+    def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
+        """
+        Pop the next unread stream chunk from the buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            The next event, or None if no unread events available
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return None
+
+        position = self._stream_positions.get(key, 0)
+        buffer = self._stream_buffers[key]
+
+        if position >= len(buffer):
+            return None
+
+        event = buffer[position]
+        self._stream_positions[key] = position + 1
+        return event
+
+    def _has_unread_stream(self, selector: Sequence[str]) -> bool:
+        """
+        Check if the stream has unread events.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if there are unread events, False otherwise
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return False
+
+        position = self._stream_positions.get(key, 0)
+        return position < len(self._stream_buffers[key])
+
+    def _close_stream(self, selector: Sequence[str]) -> None:
+        """
+        Mark a stream as closed (no more chunks can be appended).
+
+        Args:
+            selector: List of strings identifying the stream location
+        """
+        key = tuple(selector)
+        self._closed_streams.add(key)
+
+    def _is_stream_closed(self, selector: Sequence[str]) -> bool:
+        """
+        Check if a stream is closed.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if the stream is closed, False otherwise
+        """
+        key = tuple(selector)
+        return key in self._closed_streams

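These three dictionaries reproduce the deleted `Stream` class with flat data structures: an append-only buffer, a per-selector read cursor, and a closed set. A self-contained sketch of the read side, again with plain strings standing in for chunk events:

```python
# Self-contained read-side sketch: a per-selector cursor into an
# append-only buffer, with plain strings standing in for chunk events.
buffers: dict[tuple[str, ...], list[str]] = {("node1", "text"): ["Hel", "lo"]}
positions: dict[tuple[str, ...], int] = {}

def pop_chunk(key: tuple[str, ...]) -> str | None:
    pos = positions.get(key, 0)
    buf = buffers.get(key, [])
    if pos >= len(buf):
        return None                # nothing unread
    positions[key] = pos + 1       # advance the cursor, keep the buffer
    return buf[pos]

assert pop_chunk(("node1", "text")) == "Hel"
assert pop_chunk(("node1", "text")) == "lo"
assert pop_chunk(("node1", "text")) is None  # drained
```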

@@ -1,135 +0,0 @@
-from uuid import uuid4
-
-import pytest
-
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import NodeType
-from core.workflow.graph_engine.output_registry import OutputRegistry
-from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-class TestOutputRegistry:
-    def test_scalar_operations(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Test setting and getting scalar
-        registry.set_scalar(["node1", "output"], "test_value")
-
-        segment = registry.get_scalar(["node1", "output"])
-        assert segment
-        assert segment.text == "test_value"
-
-        # Test getting non-existent scalar
-        assert registry.get_scalar(["non_existent"]) is None
-
-    def test_stream_operations(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Create test events
-        event1 = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk1",
-            is_final=False,
-        )
-        event2 = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk2",
-            is_final=True,
-        )
-
-        # Test appending events
-        registry.append_chunk(["node1", "stream"], event1)
-        registry.append_chunk(["node1", "stream"], event2)
-
-        # Test has_unread
-        assert registry.has_unread(["node1", "stream"]) is True
-
-        # Test popping events
-        popped_event1 = registry.pop_chunk(["node1", "stream"])
-        assert popped_event1 == event1
-        assert popped_event1.chunk == "chunk1"
-
-        popped_event2 = registry.pop_chunk(["node1", "stream"])
-        assert popped_event2 == event2
-        assert popped_event2.chunk == "chunk2"
-
-        assert registry.pop_chunk(["node1", "stream"]) is None
-
-        # Test has_unread after popping all
-        assert registry.has_unread(["node1", "stream"]) is False
-
-    def test_stream_closing(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Test stream is not closed initially
-        assert registry.stream_closed(["node1", "stream"]) is False
-
-        # Test closing stream
-        registry.close_stream(["node1", "stream"])
-        assert registry.stream_closed(["node1", "stream"]) is True
-
-        # Test appending to closed stream raises error
-        event = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk",
-            is_final=False,
-        )
-        with pytest.raises(ValueError, match="Stream node1.stream is already closed"):
-            registry.append_chunk(["node1", "stream"], event)
-
-    def test_thread_safety(self):
-        import threading
-
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-        results = []
-
-        def append_chunks(thread_id: int):
-            for i in range(100):
-                event = NodeRunStreamChunkEvent(
-                    id=str(uuid4()),
-                    node_id="test_node",
-                    node_type=NodeType.LLM,
-                    selector=["stream"],
-                    chunk=f"thread{thread_id}_chunk{i}",
-                    is_final=False,
-                )
-                registry.append_chunk(["stream"], event)
-
-        # Start multiple threads
-        threads = []
-        for i in range(5):
-            thread = threading.Thread(target=append_chunks, args=(i,))
-            threads.append(thread)
-            thread.start()
-
-        # Wait for threads
-        for thread in threads:
-            thread.join()
-
-        # Verify all events are present
-        events = []
-        while True:
-            event = registry.pop_chunk(["stream"])
-            if event is None:
-                break
-            events.append(event)
-
-        assert len(events) == 500  # 5 threads * 100 events each
-        # Verify the events have the expected chunk content format
-        chunk_texts = [e.chunk for e in events]
-        for i in range(5):
-            for j in range(100):
-                assert f"thread{i}_chunk{j}" in chunk_texts

@@ -1,347 +0,0 @@
-"""Test cases for ResponseStreamCoordinator."""
-
-from unittest.mock import Mock
-
-from core.variables import StringSegment
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import NodeState, NodeType
-from core.workflow.graph import Graph
-from core.workflow.graph_engine.output_registry import OutputRegistry
-from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
-from core.workflow.graph_engine.response_coordinator.session import ResponseSession
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
-
-
-class TestResponseStreamCoordinator:
-    """Test cases for ResponseStreamCoordinator."""
-
-    def test_skip_variable_segment_from_skipped_node(self):
-        """Test that VariableSegments from skipped nodes are properly skipped during try_flush."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes
-        skipped_node = Mock(spec=Node)
-        skipped_node.id = "skipped_node"
-        skipped_node.state = NodeState.SKIPPED
-        skipped_node.node_type = NodeType.LLM
-
-        active_node = Mock(spec=Node)
-        active_node.id = "active_node"
-        active_node.state = NodeState.TAKEN
-        active_node.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"skipped_node": skipped_node, "active_node": active_node, "response_node": response_node}
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add some test data to registry for the active node
-        registry.set_scalar(("active_node", "output"), StringSegment(value="Active output"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with segments from both skipped and active nodes
-        template = Template(
-            segments=[
-                VariableSegment(selector=["skipped_node", "output"]),
-                TextSegment(text=" - "),
-                VariableSegment(selector=["active_node", "output"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Verify that:
-        # 1. The skipped node's variable segment was skipped (index advanced)
-        # 2. The text segment was processed
-        # 3. The active node's variable segment was processed
-        assert len(events) == 2  # TextSegment + VariableSegment from active_node
-
-        # Check that the first event is the text segment
-        assert events[0].chunk == " - "
-
-        # Check that the second event is from the active node
-        assert events[1].chunk == "Active output"
-        assert events[1].selector == ["active_node", "output"]
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_process_variable_segment_from_non_skipped_node(self):
-        """Test that VariableSegments from non-skipped nodes are processed normally."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes
-        active_node1 = Mock(spec=Node)
-        active_node1.id = "node1"
-        active_node1.state = NodeState.TAKEN
-        active_node1.node_type = NodeType.LLM
-
-        active_node2 = Mock(spec=Node)
-        active_node2.id = "node2"
-        active_node2.state = NodeState.TAKEN
-        active_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"node1": active_node1, "node2": active_node2, "response_node": response_node}
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add test data to registry
-        registry.set_scalar(("node1", "output"), StringSegment(value="Output 1"))
-        registry.set_scalar(("node2", "output"), StringSegment(value="Output 2"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with segments from active nodes
-        template = Template(
-            segments=[
-                VariableSegment(selector=["node1", "output"]),
-                TextSegment(text=" | "),
-                VariableSegment(selector=["node2", "output"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Verify all segments were processed
-        assert len(events) == 3
-
-        # Check events in order
-        assert events[0].chunk == "Output 1"
-        assert events[0].selector == ["node1", "output"]
-
-        assert events[1].chunk == " | "
-
-        assert events[2].chunk == "Output 2"
-        assert events[2].selector == ["node2", "output"]
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_mixed_skipped_and_active_nodes(self):
-        """Test processing with a mix of skipped and active nodes."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes with various states
-        skipped_node1 = Mock(spec=Node)
-        skipped_node1.id = "skip1"
-        skipped_node1.state = NodeState.SKIPPED
-        skipped_node1.node_type = NodeType.LLM
-
-        active_node = Mock(spec=Node)
-        active_node.id = "active"
-        active_node.state = NodeState.TAKEN
-        active_node.node_type = NodeType.LLM
-
-        skipped_node2 = Mock(spec=Node)
-        skipped_node2.id = "skip2"
-        skipped_node2.state = NodeState.SKIPPED
-        skipped_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {
-            "skip1": skipped_node1,
-            "active": active_node,
-            "skip2": skipped_node2,
-            "response_node": response_node,
-        }
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add data only for active node
-        registry.set_scalar(("active", "result"), StringSegment(value="Active Result"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with mixed segments
-        template = Template(
-            segments=[
-                TextSegment(text="Start: "),
-                VariableSegment(selector=["skip1", "output"]),
-                VariableSegment(selector=["active", "result"]),
-                VariableSegment(selector=["skip2", "output"]),
-                TextSegment(text=" :End"),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should have: "Start: ", "Active Result", " :End"
-        assert len(events) == 3
-
-        assert events[0].chunk == "Start: "
-        assert events[1].chunk == "Active Result"
-        assert events[1].selector == ["active", "result"]
-        assert events[2].chunk == " :End"
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_all_variable_segments_skipped(self):
-        """Test when all VariableSegments are from skipped nodes."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create all skipped nodes
-        skipped_node1 = Mock(spec=Node)
-        skipped_node1.id = "skip1"
-        skipped_node1.state = NodeState.SKIPPED
-        skipped_node1.node_type = NodeType.LLM
-
-        skipped_node2 = Mock(spec=Node)
-        skipped_node2.id = "skip2"
-        skipped_node2.state = NodeState.SKIPPED
-        skipped_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"skip1": skipped_node1, "skip2": skipped_node2, "response_node": response_node}
-
-        # Create output registry (empty since nodes are skipped) with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with only skipped segments
-        template = Template(
-            segments=[
-                VariableSegment(selector=["skip1", "output"]),
-                VariableSegment(selector=["skip2", "output"]),
-                TextSegment(text="Final text"),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should only have the final text segment
-        assert len(events) == 1
-        assert events[0].chunk == "Final text"
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_special_prefix_selectors(self):
-        """Test that special prefix selectors (sys, env, conversation) are handled correctly."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create response node
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary (no sys, env, conversation nodes)
-        graph.nodes = {"response_node": response_node}
-
-        # Create output registry with special selector data and variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-        registry.set_scalar(("sys", "user_id"), StringSegment(value="user123"))
-        registry.set_scalar(("env", "api_key"), StringSegment(value="key456"))
-        registry.set_scalar(("conversation", "id"), StringSegment(value="conv789"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with special selectors
-        template = Template(
-            segments=[
-                TextSegment(text="User: "),
-                VariableSegment(selector=["sys", "user_id"]),
-                TextSegment(text=", API: "),
-                VariableSegment(selector=["env", "api_key"]),
-                TextSegment(text=", Conv: "),
-                VariableSegment(selector=["conversation", "id"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should have all segments processed
-        assert len(events) == 6
-
-        # Check text segments
-        assert events[0].chunk == "User: "
-        assert events[0].node_id == "response_node"
-
-        # Check sys selector - should use response node's info
-        assert events[1].chunk == "user123"
-        assert events[1].selector == ["sys", "user_id"]
-        assert events[1].node_id == "response_node"
-        assert events[1].node_type == NodeType.ANSWER
-
-        assert events[2].chunk == ", API: "
-
-        # Check env selector - should use response node's info
-        assert events[3].chunk == "key456"
-        assert events[3].selector == ["env", "api_key"]
-        assert events[3].node_id == "response_node"
-        assert events[3].node_type == NodeType.ANSWER
-
-        assert events[4].chunk == ", Conv: "
-
-        # Check conversation selector - should use response node's info
-        assert events[5].chunk == "conv789"
-        assert events[5].selector == ["conversation", "id"]
-        assert events[5].node_id == "response_node"
-        assert events[5].node_type == NodeType.ANSWER
-
-        # Session should be complete
-        assert session.is_complete()