mirror of
https://github.com/langgenius/dify.git
synced 2026-01-17 12:29:57 +00:00
Compare commits
1 Commits
feature/ta
...
feat/llm-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5db06175de |
@@ -109,6 +109,7 @@ class ModelInstance:
|
||||
stream: Literal[True] = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Generator: ...
|
||||
|
||||
@overload
|
||||
@@ -121,6 +122,7 @@ class ModelInstance:
|
||||
stream: Literal[False] = False,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
@@ -133,6 +135,7 @@ class ModelInstance:
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Union[LLMResult, Generator]: ...
|
||||
|
||||
def invoke_llm(
|
||||
@@ -144,6 +147,7 @@ class ModelInstance:
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@@ -155,26 +159,33 @@ class ModelInstance:
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:param first_token_timeout: timeout in seconds for receiving first token (streaming only)
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
|
||||
result = self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
# Apply first token timeout wrapper for streaming responses
|
||||
if stream and first_token_timeout and first_token_timeout > 0 and isinstance(result, Generator):
|
||||
from core.workflow.utils.generator_timeout import with_first_token_timeout
|
||||
|
||||
result = with_first_token_timeout(result, first_token_timeout)
|
||||
|
||||
return cast(Union[LLMResult, Generator], result)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None
|
||||
) -> int:
|
||||
|
||||
@@ -23,10 +23,22 @@ class RetryConfig(BaseModel):
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
# First token timeout for LLM nodes (milliseconds), 0 means no timeout
|
||||
first_token_timeout: int = 0
|
||||
|
||||
@property
|
||||
def first_token_timeout_seconds(self) -> float:
|
||||
return self.first_token_timeout / 1000
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
@property
|
||||
def has_first_token_timeout(self) -> bool:
|
||||
"""Check if first token timeout should be applied (retry enabled and timeout > 0)."""
|
||||
return self.retry_enabled and self.first_token_timeout > 0
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -43,3 +43,11 @@ class FileTypeNotSupportError(LLMNodeError):
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
||||
|
||||
class LLMFirstTokenTimeoutError(LLMNodeError):
|
||||
"""Raised when LLM request fails to receive first token within configured timeout."""
|
||||
|
||||
def __init__(self, timeout_ms: int):
|
||||
self.timeout_ms = timeout_ms
|
||||
super().__init__(f"LLM request timed out after {timeout_ms}ms without receiving first token")
|
||||
|
||||
@@ -237,6 +237,13 @@ class LLMNode(Node[LLMNodeData]):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
# Get first token timeout from retry config if enabled (convert ms to seconds)
|
||||
first_token_timeout = (
|
||||
self.node_data.retry_config.first_token_timeout_seconds
|
||||
if self.node_data.retry_config.has_first_token_timeout
|
||||
else None
|
||||
)
|
||||
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=self.node_data.model,
|
||||
model_instance=model_instance,
|
||||
@@ -250,6 +257,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self.node_data.reasoning_format,
|
||||
first_token_timeout=first_token_timeout,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
@@ -367,6 +375,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
first_token_timeout: float | None = None,
|
||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
@@ -400,6 +409,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=user_id,
|
||||
first_token_timeout=first_token_timeout,
|
||||
)
|
||||
|
||||
return LLMNode.handle_invoke_result(
|
||||
|
||||
54
api/core/workflow/utils/generator_timeout.py
Normal file
54
api/core/workflow/utils/generator_timeout.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Generator timeout utilities for workflow nodes.
|
||||
|
||||
Provides timeout wrappers for streaming generators, primarily used for
|
||||
LLM response streaming where we need to enforce time-to-first-token limits.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class FirstTokenTimeoutError(Exception):
|
||||
"""Raised when a generator fails to yield its first item within the configured timeout."""
|
||||
|
||||
def __init__(self, timeout_ms: int):
|
||||
self.timeout_ms = timeout_ms
|
||||
super().__init__(f"Generator timed out after {timeout_ms}ms without yielding first item")
|
||||
|
||||
|
||||
def with_first_token_timeout(
|
||||
generator: Generator[T, None, None],
|
||||
timeout_seconds: float,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Wrap a generator with first token timeout monitoring.
|
||||
|
||||
Only monitors the time until the FIRST item is yielded.
|
||||
Once the first item arrives, timeout monitoring stops and
|
||||
subsequent items are yielded without timeout checks.
|
||||
|
||||
Args:
|
||||
generator: The source generator to wrap
|
||||
timeout_seconds: Maximum time to wait for first item (in seconds)
|
||||
|
||||
Yields:
|
||||
Items from the source generator
|
||||
|
||||
Raises:
|
||||
FirstTokenTimeoutError: If first item doesn't arrive within timeout
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
first_token_received = False
|
||||
|
||||
for item in generator:
|
||||
if not first_token_received:
|
||||
current_time = time.monotonic()
|
||||
if current_time - start_time > timeout_seconds:
|
||||
raise FirstTokenTimeoutError(int(timeout_seconds * 1000))
|
||||
first_token_received = True
|
||||
|
||||
yield item
|
||||
@@ -0,0 +1,416 @@
|
||||
"""Tests for LLM Node first token timeout retry functionality."""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.workflow.nodes.base.entities import RetryConfig
|
||||
from core.workflow.nodes.llm.exc import LLMFirstTokenTimeoutError
|
||||
from core.workflow.utils.generator_timeout import FirstTokenTimeoutError, with_first_token_timeout
|
||||
|
||||
|
||||
class TestRetryConfigFirstTokenTimeout:
|
||||
"""Test cases for RetryConfig first token timeout fields."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test that first token timeout fields have correct default values."""
|
||||
config = RetryConfig()
|
||||
|
||||
assert config.first_token_timeout == 0
|
||||
assert config.has_first_token_timeout is False
|
||||
|
||||
def test_has_first_token_timeout_when_retry_enabled_and_positive(self):
|
||||
"""Test has_first_token_timeout returns True when retry enabled with positive timeout."""
|
||||
config = RetryConfig(
|
||||
retry_enabled=True,
|
||||
first_token_timeout=3000, # 3000ms = 3s
|
||||
)
|
||||
|
||||
assert config.has_first_token_timeout is True
|
||||
assert config.first_token_timeout_seconds == 3.0
|
||||
|
||||
def test_has_first_token_timeout_when_retry_disabled(self):
|
||||
"""Test has_first_token_timeout returns False when retry is disabled."""
|
||||
config = RetryConfig(
|
||||
retry_enabled=False,
|
||||
first_token_timeout=60,
|
||||
)
|
||||
|
||||
assert config.has_first_token_timeout is False
|
||||
|
||||
def test_has_first_token_timeout_when_zero_timeout(self):
|
||||
"""Test has_first_token_timeout returns False when timeout is 0."""
|
||||
config = RetryConfig(
|
||||
retry_enabled=True,
|
||||
first_token_timeout=0,
|
||||
)
|
||||
|
||||
assert config.has_first_token_timeout is False
|
||||
|
||||
def test_backward_compatibility(self):
|
||||
"""Test that existing workflows without first_token_timeout work correctly."""
|
||||
old_config_data = {
|
||||
"max_retries": 3,
|
||||
"retry_interval": 1000,
|
||||
"retry_enabled": True,
|
||||
}
|
||||
|
||||
config = RetryConfig.model_validate(old_config_data)
|
||||
|
||||
assert config.max_retries == 3
|
||||
assert config.retry_interval == 1000
|
||||
assert config.retry_enabled is True
|
||||
assert config.first_token_timeout == 0
|
||||
# has_first_token_timeout is False because timeout is 0
|
||||
assert config.has_first_token_timeout is False
|
||||
|
||||
def test_full_config_serialization(self):
|
||||
"""Test that full config can be serialized and deserialized."""
|
||||
config = RetryConfig(
|
||||
max_retries=5,
|
||||
retry_interval=2000,
|
||||
retry_enabled=True,
|
||||
first_token_timeout=120,
|
||||
)
|
||||
|
||||
config_dict = config.model_dump()
|
||||
restored_config = RetryConfig.model_validate(config_dict)
|
||||
|
||||
assert restored_config.max_retries == 5
|
||||
assert restored_config.retry_interval == 2000
|
||||
assert restored_config.retry_enabled is True
|
||||
assert restored_config.first_token_timeout == 120
|
||||
assert restored_config.has_first_token_timeout is True
|
||||
|
||||
|
||||
class TestLLMFirstTokenTimeoutError:
|
||||
"""Test cases for LLMFirstTokenTimeoutError exception."""
|
||||
|
||||
def test_error_message_format(self):
|
||||
"""Test that error message contains timeout value in milliseconds."""
|
||||
error = LLMFirstTokenTimeoutError(timeout_ms=3000)
|
||||
|
||||
assert "3000ms" in str(error)
|
||||
assert "first token" in str(error).lower()
|
||||
|
||||
def test_inherits_from_llm_node_error(self):
|
||||
"""Test that LLMFirstTokenTimeoutError inherits from LLMNodeError."""
|
||||
from core.workflow.nodes.llm.exc import LLMNodeError
|
||||
|
||||
error = LLMFirstTokenTimeoutError(timeout_ms=3000)
|
||||
|
||||
assert isinstance(error, LLMNodeError)
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
|
||||
class TestWithFirstTokenTimeout:
|
||||
"""Test cases for with_first_token_timeout function."""
|
||||
|
||||
@staticmethod
|
||||
def _create_mock_chunk(text: str = "test") -> LLMResultChunk:
|
||||
"""Helper to create a mock LLMResultChunk."""
|
||||
return LLMResultChunk(
|
||||
model="test-model",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
),
|
||||
)
|
||||
|
||||
def test_first_token_arrives_within_timeout(self):
|
||||
"""Test that chunks are yielded normally when first token arrives in time."""
|
||||
|
||||
def mock_generator() -> Generator[LLMResultChunk, None, None]:
|
||||
yield self._create_mock_chunk("Hello")
|
||||
yield self._create_mock_chunk(" world")
|
||||
|
||||
wrapped = with_first_token_timeout(mock_generator(), timeout_seconds=10)
|
||||
chunks = list(wrapped)
|
||||
|
||||
assert len(chunks) == 2
|
||||
|
||||
def test_first_token_timeout_raises_error(self, monkeypatch):
|
||||
"""Test that timeout error is raised when first token doesn't arrive in time."""
|
||||
call_count = 0
|
||||
|
||||
def mock_monotonic():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# First call: start_time = 0
|
||||
# Second call (when checking): current_time = 11 (exceeds 10 second timeout)
|
||||
if call_count == 1:
|
||||
return 0.0
|
||||
return 11.0
|
||||
|
||||
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||
|
||||
def slow_generator() -> Generator[LLMResultChunk, None, None]:
|
||||
# This chunk arrives "after timeout"
|
||||
yield self._create_mock_chunk("Late token")
|
||||
|
||||
wrapped = with_first_token_timeout(slow_generator(), timeout_seconds=10)
|
||||
|
||||
with pytest.raises(FirstTokenTimeoutError) as exc_info:
|
||||
list(wrapped)
|
||||
|
||||
# Error message shows milliseconds (10 seconds = 10000ms)
|
||||
assert "10000ms" in str(exc_info.value)
|
||||
|
||||
def test_no_timeout_check_after_first_token(self, monkeypatch):
|
||||
"""Test that subsequent chunks are not subject to timeout after first token received."""
|
||||
call_count = 0
|
||||
|
||||
def mock_monotonic():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return 0.0 # start_time
|
||||
elif call_count == 2:
|
||||
return 5.0 # first token arrives at 5s (within 10s timeout)
|
||||
else:
|
||||
# Subsequent calls simulate long delays for remaining chunks
|
||||
# These should NOT trigger timeout because first token already received
|
||||
return 100.0 + call_count
|
||||
|
||||
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||
|
||||
def generator_with_slow_subsequent_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield self._create_mock_chunk("First")
|
||||
yield self._create_mock_chunk("Second")
|
||||
yield self._create_mock_chunk("Third")
|
||||
|
||||
wrapped = with_first_token_timeout(
|
||||
generator_with_slow_subsequent_chunks(),
|
||||
timeout_seconds=10,
|
||||
)
|
||||
|
||||
# Should not raise, even though "time" passes beyond timeout after first token
|
||||
chunks = list(wrapped)
|
||||
assert len(chunks) == 3
|
||||
|
||||
def test_empty_generator_no_error(self):
|
||||
"""Test that empty generator doesn't raise timeout error (no chunks to check)."""
|
||||
|
||||
def empty_generator() -> Generator[LLMResultChunk, None, None]:
|
||||
return
|
||||
yield # unreachable, but makes this a generator
|
||||
|
||||
wrapped = with_first_token_timeout(empty_generator(), timeout_seconds=10)
|
||||
chunks = list(wrapped)
|
||||
|
||||
assert chunks == []
|
||||
|
||||
def test_exact_timeout_boundary(self, monkeypatch):
|
||||
"""Test behavior at exact timeout boundary (should not raise when equal)."""
|
||||
call_count = 0
|
||||
|
||||
def mock_monotonic():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return 0.0
|
||||
# Exactly at boundary: current_time - start_time = 10, timeout_seconds = 10
|
||||
# Since we check > not >=, this should NOT raise
|
||||
return 10.0
|
||||
|
||||
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||
|
||||
def generator() -> Generator[LLMResultChunk, None, None]:
|
||||
yield self._create_mock_chunk("Token at boundary")
|
||||
|
||||
wrapped = with_first_token_timeout(generator(), timeout_seconds=10)
|
||||
|
||||
# Should not raise because 10 is not > 10
|
||||
chunks = list(wrapped)
|
||||
assert len(chunks) == 1
|
||||
|
||||
def test_just_over_timeout_boundary(self, monkeypatch):
|
||||
"""Test behavior just over timeout boundary (should raise)."""
|
||||
call_count = 0
|
||||
|
||||
def mock_monotonic():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return 0.0
|
||||
# Just over boundary
|
||||
return 10.001
|
||||
|
||||
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||
|
||||
def generator() -> Generator[LLMResultChunk, None, None]:
|
||||
yield self._create_mock_chunk("Late token")
|
||||
|
||||
wrapped = with_first_token_timeout(generator(), timeout_seconds=10)
|
||||
|
||||
with pytest.raises(FirstTokenTimeoutError):
|
||||
list(wrapped)
|
||||
|
||||
|
||||
class TestLLMNodeInvokeLLMWithTimeout:
|
||||
"""Test cases for LLMNode.invoke_llm with first_token_timeout parameter."""
|
||||
|
||||
def test_invoke_llm_without_timeout(self):
|
||||
"""Test invoke_llm works normally when first_token_timeout is None."""
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
|
||||
with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
|
||||
mock_handle.return_value = iter([])
|
||||
|
||||
# Mock model_instance.invoke_llm to return empty generator
|
||||
mock_model_instance = mock.MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = iter([])
|
||||
mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()
|
||||
|
||||
mock_node_data_model = mock.MagicMock()
|
||||
mock_node_data_model.completion_params = {}
|
||||
|
||||
result = LLMNode.invoke_llm(
|
||||
node_data_model=mock_node_data_model,
|
||||
model_instance=mock_model_instance,
|
||||
prompt_messages=[],
|
||||
user_id="test-user",
|
||||
structured_output_enabled=False,
|
||||
structured_output=None,
|
||||
file_saver=mock.MagicMock(),
|
||||
file_outputs=[],
|
||||
node_id="test-node",
|
||||
node_type=mock.MagicMock(),
|
||||
first_token_timeout=None, # No timeout
|
||||
)
|
||||
|
||||
list(result) # Consume generator
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_invoke_llm_with_timeout_passes_to_model_instance(self):
|
||||
"""Test invoke_llm passes first_token_timeout to model_instance.invoke_llm."""
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
|
||||
with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
|
||||
mock_handle.return_value = iter([])
|
||||
|
||||
mock_model_instance = mock.MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = iter([])
|
||||
mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()
|
||||
|
||||
mock_node_data_model = mock.MagicMock()
|
||||
mock_node_data_model.completion_params = {}
|
||||
|
||||
result = LLMNode.invoke_llm(
|
||||
node_data_model=mock_node_data_model,
|
||||
model_instance=mock_model_instance,
|
||||
prompt_messages=[],
|
||||
user_id="test-user",
|
||||
structured_output_enabled=False,
|
||||
structured_output=None,
|
||||
file_saver=mock.MagicMock(),
|
||||
file_outputs=[],
|
||||
node_id="test-node",
|
||||
node_type=mock.MagicMock(),
|
||||
first_token_timeout=60, # With timeout
|
||||
)
|
||||
|
||||
list(result) # Consume generator
|
||||
|
||||
# Verify model_instance.invoke_llm was called with first_token_timeout
|
||||
mock_model_instance.invoke_llm.assert_called_once()
|
||||
call_kwargs = mock_model_instance.invoke_llm.call_args.kwargs
|
||||
assert call_kwargs.get("first_token_timeout") == 60
|
||||
|
||||
def test_invoke_llm_with_zero_timeout_passes_zero(self):
|
||||
"""Test invoke_llm passes zero timeout to model_instance."""
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
|
||||
with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
|
||||
mock_handle.return_value = iter([])
|
||||
|
||||
mock_model_instance = mock.MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = iter([])
|
||||
mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()
|
||||
|
||||
mock_node_data_model = mock.MagicMock()
|
||||
mock_node_data_model.completion_params = {}
|
||||
|
||||
result = LLMNode.invoke_llm(
|
||||
node_data_model=mock_node_data_model,
|
||||
model_instance=mock_model_instance,
|
||||
prompt_messages=[],
|
||||
user_id="test-user",
|
||||
structured_output_enabled=False,
|
||||
structured_output=None,
|
||||
file_saver=mock.MagicMock(),
|
||||
file_outputs=[],
|
||||
node_id="test-node",
|
||||
node_type=mock.MagicMock(),
|
||||
first_token_timeout=0, # Zero timeout
|
||||
)
|
||||
|
||||
list(result) # Consume generator
|
||||
|
||||
# Verify model_instance.invoke_llm was called with zero timeout
|
||||
mock_model_instance.invoke_llm.assert_called_once()
|
||||
call_kwargs = mock_model_instance.invoke_llm.call_args.kwargs
|
||||
assert call_kwargs.get("first_token_timeout") == 0
|
||||
|
||||
|
||||
class TestRetryConfigIntegration:
|
||||
"""Integration tests for RetryConfig with LLM node data."""
|
||||
|
||||
def test_retry_config_in_node_data(self):
|
||||
"""Test RetryConfig can be properly configured in LLMNodeData."""
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
|
||||
|
||||
node_data = LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(
|
||||
provider="openai",
|
||||
name="gpt-4",
|
||||
mode=LLMMode.CHAT,
|
||||
completion_params={},
|
||||
),
|
||||
prompt_template=[],
|
||||
context=ContextConfig(enabled=False),
|
||||
structured_output_enabled=False,
|
||||
retry_config=RetryConfig(
|
||||
max_retries=3,
|
||||
retry_interval=1000,
|
||||
retry_enabled=True,
|
||||
first_token_timeout=3000, # 3000ms = 3s
|
||||
),
|
||||
)
|
||||
|
||||
assert node_data.retry_config.max_retries == 3
|
||||
assert node_data.retry_config.retry_enabled is True
|
||||
assert node_data.retry_config.first_token_timeout == 3000
|
||||
assert node_data.retry_config.first_token_timeout_seconds == 3.0
|
||||
assert node_data.retry_config.has_first_token_timeout is True
|
||||
|
||||
def test_default_retry_config_in_node_data(self):
|
||||
"""Test default RetryConfig in LLMNodeData."""
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
|
||||
|
||||
node_data = LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(
|
||||
provider="openai",
|
||||
name="gpt-4",
|
||||
mode=LLMMode.CHAT,
|
||||
completion_params={},
|
||||
),
|
||||
prompt_template=[],
|
||||
context=ContextConfig(enabled=False),
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
|
||||
# Should have default RetryConfig
|
||||
assert node_data.retry_config.max_retries == 0
|
||||
assert node_data.retry_config.retry_enabled is False
|
||||
assert node_data.retry_config.first_token_timeout == 0
|
||||
assert node_data.retry_config.has_first_token_timeout is False
|
||||
@@ -9,6 +9,10 @@ import Split from '@/app/components/workflow/nodes/_base/components/split'
|
||||
import { useRetryConfig } from './hooks'
|
||||
import s from './style.module.css'
|
||||
|
||||
// Nodes that support first token timeout configuration
|
||||
// These nodes internally call LLM and have streaming response characteristics
|
||||
const LLM_RELATED_NODE_TYPES = ['llm', 'agent', 'parameter-extractor', 'question-classifier']
|
||||
|
||||
type RetryOnPanelProps = Pick<Node, 'id' | 'data'>
|
||||
const RetryOnPanel = ({
|
||||
id,
|
||||
@@ -16,10 +20,14 @@ const RetryOnPanel = ({
|
||||
}: RetryOnPanelProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { handleRetryConfigChange } = useRetryConfig(id)
|
||||
const { retry_config } = data
|
||||
const { retry_config, type } = data
|
||||
|
||||
// Check if this is an LLM-related node that supports first token timeout
|
||||
const isLLMRelatedNode = LLM_RELATED_NODE_TYPES.includes(type)
|
||||
|
||||
const handleRetryEnabledChange = (value: boolean) => {
|
||||
handleRetryConfigChange({
|
||||
...retry_config,
|
||||
retry_enabled: value,
|
||||
max_retries: retry_config?.max_retries || 3,
|
||||
retry_interval: retry_config?.retry_interval || 1000,
|
||||
@@ -32,6 +40,7 @@ const RetryOnPanel = ({
|
||||
else if (value < 1)
|
||||
value = 1
|
||||
handleRetryConfigChange({
|
||||
...retry_config,
|
||||
retry_enabled: true,
|
||||
max_retries: value,
|
||||
retry_interval: retry_config?.retry_interval || 1000,
|
||||
@@ -44,12 +53,27 @@ const RetryOnPanel = ({
|
||||
else if (value < 100)
|
||||
value = 100
|
||||
handleRetryConfigChange({
|
||||
...retry_config,
|
||||
retry_enabled: true,
|
||||
max_retries: retry_config?.max_retries || 3,
|
||||
retry_interval: value,
|
||||
})
|
||||
}
|
||||
|
||||
const handleFirstTokenTimeoutChange = (value: number) => {
|
||||
if (value > 60000)
|
||||
value = 60000
|
||||
else if (value < 0)
|
||||
value = 0
|
||||
handleRetryConfigChange({
|
||||
...retry_config,
|
||||
retry_enabled: true,
|
||||
max_retries: retry_config?.max_retries || 3,
|
||||
retry_interval: retry_config?.retry_interval || 1000,
|
||||
first_token_timeout: value,
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="pt-2">
|
||||
@@ -62,54 +86,76 @@ const RetryOnPanel = ({
|
||||
onChange={v => handleRetryEnabledChange(v)}
|
||||
/>
|
||||
</div>
|
||||
{
|
||||
retry_config?.retry_enabled && (
|
||||
<div className="px-4 pb-2">
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.maxRetries', { ns: 'workflow' })}</div>
|
||||
{retry_config?.retry_enabled && (
|
||||
<div className="px-4 pb-2">
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.maxRetries', { ns: 'workflow' })}</div>
|
||||
<Slider
|
||||
className="mr-3 w-[108px]"
|
||||
value={retry_config?.max_retries || 3}
|
||||
onChange={handleMaxRetriesChange}
|
||||
min={1}
|
||||
max={10}
|
||||
/>
|
||||
<Input
|
||||
type="number"
|
||||
wrapperClassName="w-[100px]"
|
||||
value={retry_config?.max_retries || 3}
|
||||
onChange={e =>
|
||||
handleMaxRetriesChange(Number.parseInt(e.currentTarget.value, 10) || 3)}
|
||||
min={1}
|
||||
max={10}
|
||||
unit={t('nodes.common.retry.times', { ns: 'workflow' }) || ''}
|
||||
className={s.input}
|
||||
/>
|
||||
</div>
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.retryInterval', { ns: 'workflow' })}</div>
|
||||
<Slider
|
||||
className="mr-3 w-[108px]"
|
||||
value={retry_config?.retry_interval || 1000}
|
||||
onChange={handleRetryIntervalChange}
|
||||
min={100}
|
||||
max={5000}
|
||||
/>
|
||||
<Input
|
||||
type="number"
|
||||
wrapperClassName="w-[100px]"
|
||||
value={retry_config?.retry_interval || 1000}
|
||||
onChange={e =>
|
||||
handleRetryIntervalChange(Number.parseInt(e.currentTarget.value, 10) || 1000)}
|
||||
min={100}
|
||||
max={5000}
|
||||
unit={t('nodes.common.retry.ms', { ns: 'workflow' }) || ''}
|
||||
className={s.input}
|
||||
/>
|
||||
</div>
|
||||
{/* First token timeout - only for LLM-related nodes */}
|
||||
{isLLMRelatedNode && (
|
||||
<div className="flex w-full items-center">
|
||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.firstTokenTimeout', { ns: 'workflow' })}</div>
|
||||
<Slider
|
||||
className="mr-3 w-[108px]"
|
||||
value={retry_config?.max_retries || 3}
|
||||
onChange={handleMaxRetriesChange}
|
||||
min={1}
|
||||
max={10}
|
||||
value={retry_config?.first_token_timeout ?? 3000}
|
||||
onChange={handleFirstTokenTimeoutChange}
|
||||
min={0}
|
||||
max={60000}
|
||||
/>
|
||||
<Input
|
||||
type="number"
|
||||
wrapperClassName="w-[100px]"
|
||||
value={retry_config?.max_retries || 3}
|
||||
value={retry_config?.first_token_timeout ?? 3000}
|
||||
onChange={e =>
|
||||
handleMaxRetriesChange(Number.parseInt(e.currentTarget.value, 10) || 3)}
|
||||
min={1}
|
||||
max={10}
|
||||
unit={t('nodes.common.retry.times', { ns: 'workflow' }) || ''}
|
||||
className={s.input}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.retryInterval', { ns: 'workflow' })}</div>
|
||||
<Slider
|
||||
className="mr-3 w-[108px]"
|
||||
value={retry_config?.retry_interval || 1000}
|
||||
onChange={handleRetryIntervalChange}
|
||||
min={100}
|
||||
max={5000}
|
||||
/>
|
||||
<Input
|
||||
type="number"
|
||||
wrapperClassName="w-[100px]"
|
||||
value={retry_config?.retry_interval || 1000}
|
||||
onChange={e =>
|
||||
handleRetryIntervalChange(Number.parseInt(e.currentTarget.value, 10) || 1000)}
|
||||
min={100}
|
||||
max={5000}
|
||||
handleFirstTokenTimeoutChange(Number.parseInt(e.currentTarget.value, 10) || 0)}
|
||||
min={0}
|
||||
max={60000}
|
||||
unit={t('nodes.common.retry.ms', { ns: 'workflow' }) || ''}
|
||||
className={s.input}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Split className="mx-4 mt-2" />
|
||||
</>
|
||||
|
||||
@@ -2,4 +2,6 @@ export type WorkflowRetryConfig = {
|
||||
max_retries: number
|
||||
retry_interval: number
|
||||
retry_enabled: boolean
|
||||
// First token timeout for LLM nodes (seconds), 0 means no timeout
|
||||
first_token_timeout?: number
|
||||
}
|
||||
|
||||
@@ -433,6 +433,7 @@
|
||||
"nodes.common.memory.windowSize": "Window Size",
|
||||
"nodes.common.outputVars": "Output Variables",
|
||||
"nodes.common.pluginNotInstalled": "Plugin is not installed",
|
||||
"nodes.common.retry.firstTokenTimeout": "First token timeout",
|
||||
"nodes.common.retry.maxRetries": "max retries",
|
||||
"nodes.common.retry.ms": "ms",
|
||||
"nodes.common.retry.retries": "{{num}} Retries",
|
||||
@@ -444,6 +445,8 @@
|
||||
"nodes.common.retry.retrySuccessful": "Retry successful",
|
||||
"nodes.common.retry.retryTimes": "Retry {{times}} times on failure",
|
||||
"nodes.common.retry.retrying": "Retrying...",
|
||||
"nodes.common.retry.seconds": "s",
|
||||
"nodes.common.retry.timeoutDuration": "Timeout duration",
|
||||
"nodes.common.retry.times": "times",
|
||||
"nodes.common.typeSwitch.input": "Input value",
|
||||
"nodes.common.typeSwitch.variable": "Use variable",
|
||||
|
||||
@@ -433,6 +433,7 @@
|
||||
"nodes.common.memory.windowSize": "记忆窗口",
|
||||
"nodes.common.outputVars": "输出变量",
|
||||
"nodes.common.pluginNotInstalled": "插件未安装",
|
||||
"nodes.common.retry.firstTokenTimeout": "首个 Token 超时",
|
||||
"nodes.common.retry.maxRetries": "最大重试次数",
|
||||
"nodes.common.retry.ms": "毫秒",
|
||||
"nodes.common.retry.retries": "{{num}} 重试次数",
|
||||
@@ -444,6 +445,8 @@
|
||||
"nodes.common.retry.retrySuccessful": "重试成功",
|
||||
"nodes.common.retry.retryTimes": "失败时重试 {{times}} 次",
|
||||
"nodes.common.retry.retrying": "重试中...",
|
||||
"nodes.common.retry.seconds": "秒",
|
||||
"nodes.common.retry.timeoutDuration": "超时时长",
|
||||
"nodes.common.retry.times": "次",
|
||||
"nodes.common.typeSwitch.input": "输入值",
|
||||
"nodes.common.typeSwitch.variable": "使用变量",
|
||||
|
||||
Reference in New Issue
Block a user