mirror of
https://github.com/langgenius/dify.git
synced 2026-03-12 02:27:05 +00:00
Compare commits
4 Commits
copilot/re
...
yanli/iter
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
597b749399 | ||
|
|
b8d3e8e5c2 | ||
|
|
f68cf971a5 | ||
|
|
0fc5c23a1c |
@@ -502,7 +502,7 @@ class WorkflowResponseConverter:
|
||||
snapshot = self._pop_snapshot(event.node_execution_id)
|
||||
|
||||
start_at = snapshot.start_at if snapshot else event.start_at
|
||||
finished_at = naive_utc_now()
|
||||
finished_at = event.finished_at or naive_utc_now()
|
||||
elapsed_time = (finished_at - start_at).total_seconds()
|
||||
|
||||
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
|
||||
|
||||
@@ -436,6 +436,7 @@ class WorkflowBasedAppRunner:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
@@ -451,6 +452,7 @@ class WorkflowBasedAppRunner:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
@@ -467,6 +469,7 @@ class WorkflowBasedAppRunner:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
|
||||
@@ -336,6 +336,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
@@ -391,6 +392,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
@@ -415,6 +417,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
finished_at: datetime | None = None
|
||||
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
@@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
@@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
@@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
*,
|
||||
error: str | None = None,
|
||||
update_outputs: bool = True,
|
||||
finished_at: datetime | None = None,
|
||||
) -> None:
|
||||
finished_at = naive_utc_now()
|
||||
actual_finished_at = finished_at or naive_utc_now()
|
||||
snapshot = self._node_snapshots.get(domain_execution.id)
|
||||
start_at = snapshot.created_at if snapshot else domain_execution.created_at
|
||||
domain_execution.status = status
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
|
||||
domain_execution.finished_at = actual_finished_at
|
||||
domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0)
|
||||
|
||||
if error:
|
||||
domain_execution.error = error
|
||||
|
||||
@@ -159,6 +159,7 @@ class ErrorHandler:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
@@ -198,6 +199,7 @@ class ErrorHandler:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
|
||||
@@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
|
||||
from typing_extensions import override
|
||||
|
||||
from dify_graph.context import IExecutionContext
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
@@ -65,6 +68,7 @@ class Worker(threading.Thread):
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
self._current_node_started_at: datetime | None = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
@@ -104,18 +108,15 @@ class Worker(threading.Thread):
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._current_node_started_at = None
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
self._event_queue.put(
|
||||
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
|
||||
)
|
||||
self._event_queue.put(error_event)
|
||||
finally:
|
||||
self._current_node_started_at = None
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
@@ -136,6 +137,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
@@ -149,6 +152,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
@@ -177,3 +182,24 @@ class Worker(threading.Thread):
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _build_fallback_failure_event(
|
||||
self, node: Node, error: Exception, *, started_at: datetime | None = None
|
||||
) -> NodeRunFailedEvent:
|
||||
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
|
||||
failure_time = naive_utc_now()
|
||||
error_message = str(error)
|
||||
return NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=error_message,
|
||||
start_at=started_at or failure_time,
|
||||
finished_at=failure_time,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
error_type=type(error).__name__,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -37,16 +37,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
class NodeRunSucceededEvent(GraphNodeEventBase):
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
|
||||
@@ -428,11 +428,13 @@ class Node(Generic[NodeDataT]):
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
finished_at = naive_utc_now()
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
@@ -601,6 +603,7 @@ class Node(Generic[NodeDataT]):
|
||||
return self._node_data
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
finished_at = naive_utc_now()
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
@@ -608,6 +611,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
@@ -617,6 +621,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
@@ -639,6 +644,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
finished_at = naive_utc_now()
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
@@ -646,6 +652,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
@@ -654,6 +661,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
|
||||
@@ -235,7 +235,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
future_to_index: dict[
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
float,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, Variable],
|
||||
@@ -260,7 +260,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
try:
|
||||
result = future.result()
|
||||
(
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
@@ -273,8 +273,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
# The worker computes duration before we replay buffered events here,
|
||||
# so slow downstream consumers don't inflate per-iteration timing.
|
||||
iter_run_map[str(index)] = iteration_duration
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
@@ -304,7 +305,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
index: int,
|
||||
item: object,
|
||||
execution_context: "IExecutionContext",
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with execution_context:
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@@ -326,9 +327,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
return (
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
|
||||
@@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
|
||||
assert response.data.process_data == {}
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_finish_response_prefers_event_finished_at(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Finished timestamps should come from the event, not delayed queue processing time."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
|
||||
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.common.workflow_response_converter.naive_utc_now",
|
||||
lambda: delayed_processing_time,
|
||||
)
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_execution_id="node-exec-1",
|
||||
start_at=start_at,
|
||||
finished_at=finished_at,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.elapsed_time == 2.0
|
||||
assert response.data.finished_at == int(finished_at.timestamp())
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
"""Test that node retry response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.workflow.layers.persistence import (
|
||||
PersistenceWorkflowInfo,
|
||||
WorkflowPersistenceLayer,
|
||||
_NodeRuntimeSnapshot,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_layer() -> WorkflowPersistenceLayer:
|
||||
application_generate_entity = Mock()
|
||||
application_generate_entity.inputs = {}
|
||||
|
||||
return WorkflowPersistenceLayer(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version="1",
|
||||
graph_data={},
|
||||
),
|
||||
workflow_execution_repository=Mock(),
|
||||
workflow_node_execution_repository=Mock(),
|
||||
)
|
||||
|
||||
|
||||
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
layer = _build_layer()
|
||||
node_execution = Mock()
|
||||
node_execution.id = "node-exec-1"
|
||||
node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
node_execution.update_from_mapping = Mock()
|
||||
|
||||
layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
|
||||
node_id="node-id",
|
||||
title="LLM",
|
||||
predecessor_node_id=None,
|
||||
iteration_id="iter-1",
|
||||
loop_id=None,
|
||||
created_at=node_execution.created_at,
|
||||
)
|
||||
|
||||
event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
|
||||
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
|
||||
monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)
|
||||
|
||||
layer._update_node_execution(
|
||||
node_execution,
|
||||
NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event_finished_at,
|
||||
)
|
||||
|
||||
assert node_execution.finished_at == event_finished_at
|
||||
assert node_execution.elapsed_time == 2.0
|
||||
@@ -0,0 +1,87 @@
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from dify_graph.graph_engine.worker import Worker
|
||||
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
|
||||
|
||||
|
||||
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
|
||||
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=InMemoryReadyQueue(),
|
||||
event_queue=queue.Queue(),
|
||||
graph=MagicMock(),
|
||||
layers=[],
|
||||
)
|
||||
node = SimpleNamespace(
|
||||
execution_id="exec-1",
|
||||
id="node-1",
|
||||
node_type=NodeType.LLM,
|
||||
)
|
||||
|
||||
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
|
||||
|
||||
assert event.start_at == fixed_time
|
||||
assert event.finished_at == fixed_time
|
||||
assert event.error == "boom"
|
||||
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert event.node_run_result.error == "boom"
|
||||
assert event.node_run_result.error_type == "RuntimeError"
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
|
||||
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
failure_time = start_at + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeNode:
|
||||
execution_id = "exec-1"
|
||||
id = "node-1"
|
||||
node_type = NodeType.LLM
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="LLM",
|
||||
start_at=start_at,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"node-1": FakeNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["node-1"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 1:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == start_at
|
||||
assert fallback_event.finished_at == failure_time
|
||||
assert fallback_event.error == "queue boom"
|
||||
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
@@ -0,0 +1,63 @@
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_events import NodeRunSucceededEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from dify_graph.nodes.iteration.iteration_node import IterationNode
|
||||
|
||||
|
||||
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
|
||||
node = IterationNode.__new__(IterationNode)
|
||||
node._node_data = IterationNodeData(
|
||||
title="Parallel Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
output_selector=["iteration", "output"],
|
||||
is_parallel=True,
|
||||
parallel_nums=2,
|
||||
error_handle_mode=ErrorHandleMode.TERMINATED,
|
||||
)
|
||||
node._capture_execution_context = lambda: nullcontext()
|
||||
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
|
||||
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
|
||||
|
||||
def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
|
||||
return (
|
||||
0.1 + (index * 0.1),
|
||||
[
|
||||
NodeRunSucceededEvent(
|
||||
id=f"exec-{index}",
|
||||
node_id=f"llm-{index}",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
),
|
||||
],
|
||||
f"output-{item}",
|
||||
{},
|
||||
LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel
|
||||
|
||||
outputs: list[object] = []
|
||||
iter_run_map: dict[str, float] = {}
|
||||
usage_accumulator = [LLMUsage.empty_usage()]
|
||||
|
||||
generator = node._execute_parallel_iterations(
|
||||
iterator_list_value=["a", "b"],
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
|
||||
for _ in generator:
|
||||
# Simulate a slow consumer replaying buffered events.
|
||||
time.sleep(0.02)
|
||||
|
||||
assert outputs == ["output-a", "output-b"]
|
||||
assert iter_run_map["0"] == pytest.approx(0.1)
|
||||
assert iter_run_map["1"] == pytest.approx(0.2)
|
||||
Reference in New Issue
Block a user