mirror of
https://github.com/langgenius/dify.git
synced 2026-03-06 15:45:14 +00:00
Compare commits
2 Commits
refactor/b
...
feat/llm-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfc1583626 | ||
|
|
5db06175de |
@@ -21,6 +21,7 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
|
|||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
|
from core.workflow.utils.generator_timeout import with_first_token_timeout
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||||
@@ -109,6 +110,7 @@ class ModelInstance:
|
|||||||
stream: Literal[True] = True,
|
stream: Literal[True] = True,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
callbacks: list[Callback] | None = None,
|
callbacks: list[Callback] | None = None,
|
||||||
|
first_token_timeout: float | None = None,
|
||||||
) -> Generator: ...
|
) -> Generator: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -121,6 +123,7 @@ class ModelInstance:
|
|||||||
stream: Literal[False] = False,
|
stream: Literal[False] = False,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
callbacks: list[Callback] | None = None,
|
callbacks: list[Callback] | None = None,
|
||||||
|
first_token_timeout: float | None = None,
|
||||||
) -> LLMResult: ...
|
) -> LLMResult: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -133,6 +136,7 @@ class ModelInstance:
|
|||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
callbacks: list[Callback] | None = None,
|
callbacks: list[Callback] | None = None,
|
||||||
|
first_token_timeout: float | None = None,
|
||||||
) -> Union[LLMResult, Generator]: ...
|
) -> Union[LLMResult, Generator]: ...
|
||||||
|
|
||||||
def invoke_llm(
|
def invoke_llm(
|
||||||
@@ -144,6 +148,7 @@ class ModelInstance:
|
|||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
callbacks: list[Callback] | None = None,
|
callbacks: list[Callback] | None = None,
|
||||||
|
first_token_timeout: float | None = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
@@ -155,26 +160,31 @@ class ModelInstance:
|
|||||||
:param stream: is stream response
|
:param stream: is stream response
|
||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:param callbacks: callbacks
|
:param callbacks: callbacks
|
||||||
|
:param first_token_timeout: timeout in seconds for receiving first token (streaming only)
|
||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||||
raise Exception("Model type instance is not LargeLanguageModel")
|
raise Exception("Model type instance is not LargeLanguageModel")
|
||||||
return cast(
|
|
||||||
Union[LLMResult, Generator],
|
result = self._round_robin_invoke(
|
||||||
self._round_robin_invoke(
|
function=self.model_type_instance.invoke,
|
||||||
function=self.model_type_instance.invoke,
|
model=self.model,
|
||||||
model=self.model,
|
credentials=self.credentials,
|
||||||
credentials=self.credentials,
|
prompt_messages=prompt_messages,
|
||||||
prompt_messages=prompt_messages,
|
model_parameters=model_parameters,
|
||||||
model_parameters=model_parameters,
|
tools=tools,
|
||||||
tools=tools,
|
stop=stop,
|
||||||
stop=stop,
|
stream=stream,
|
||||||
stream=stream,
|
user=user,
|
||||||
user=user,
|
callbacks=callbacks,
|
||||||
callbacks=callbacks,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply first token timeout wrapper for streaming responses
|
||||||
|
if stream and first_token_timeout and first_token_timeout > 0 and isinstance(result, Generator):
|
||||||
|
result = with_first_token_timeout(result, first_token_timeout)
|
||||||
|
|
||||||
|
return cast(Union[LLMResult, Generator], result)
|
||||||
|
|
||||||
def get_llm_num_tokens(
|
def get_llm_num_tokens(
|
||||||
self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None
|
self, prompt_messages: Sequence[PromptMessage], tools: Sequence[PromptMessageTool] | None = None
|
||||||
) -> int:
|
) -> int:
|
||||||
|
|||||||
@@ -23,10 +23,22 @@ class RetryConfig(BaseModel):
|
|||||||
retry_interval: int = 0 # retry interval in milliseconds
|
retry_interval: int = 0 # retry interval in milliseconds
|
||||||
retry_enabled: bool = False # whether retry is enabled
|
retry_enabled: bool = False # whether retry is enabled
|
||||||
|
|
||||||
|
# First token timeout for LLM nodes (milliseconds), 0 means no timeout
|
||||||
|
first_token_timeout: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first_token_timeout_seconds(self) -> float:
|
||||||
|
return self.first_token_timeout / 1000
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry_interval_seconds(self) -> float:
|
def retry_interval_seconds(self) -> float:
|
||||||
return self.retry_interval / 1000
|
return self.retry_interval / 1000
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_first_token_timeout(self) -> bool:
|
||||||
|
"""Check if first token timeout should be applied (retry enabled and timeout > 0)."""
|
||||||
|
return self.retry_enabled and self.first_token_timeout > 0
|
||||||
|
|
||||||
|
|
||||||
class VariableSelector(BaseModel):
|
class VariableSelector(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -237,6 +237,13 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
|
# Get first token timeout from retry config if enabled (convert ms to seconds)
|
||||||
|
first_token_timeout = (
|
||||||
|
self.node_data.retry_config.first_token_timeout_seconds
|
||||||
|
if self.node_data.retry_config.has_first_token_timeout
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
generator = LLMNode.invoke_llm(
|
generator = LLMNode.invoke_llm(
|
||||||
node_data_model=self.node_data.model,
|
node_data_model=self.node_data.model,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
@@ -250,6 +257,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
node_type=self.node_type,
|
node_type=self.node_type,
|
||||||
reasoning_format=self.node_data.reasoning_format,
|
reasoning_format=self.node_data.reasoning_format,
|
||||||
|
first_token_timeout=first_token_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
structured_output: LLMStructuredOutput | None = None
|
structured_output: LLMStructuredOutput | None = None
|
||||||
@@ -367,6 +375,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
node_id: str,
|
node_id: str,
|
||||||
node_type: NodeType,
|
node_type: NodeType,
|
||||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||||
|
first_token_timeout: float | None = None,
|
||||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
||||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||||
node_data_model.name, model_instance.credentials
|
node_data_model.name, model_instance.credentials
|
||||||
@@ -400,6 +409,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
stop=list(stop or []),
|
stop=list(stop or []),
|
||||||
stream=True,
|
stream=True,
|
||||||
user=user_id,
|
user=user_id,
|
||||||
|
first_token_timeout=first_token_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
return LLMNode.handle_invoke_result(
|
return LLMNode.handle_invoke_result(
|
||||||
|
|||||||
56
api/core/workflow/utils/generator_timeout.py
Normal file
56
api/core/workflow/utils/generator_timeout.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
Generator timeout utilities for workflow nodes.
|
||||||
|
|
||||||
|
Provides timeout wrappers for streaming generators, primarily used for
|
||||||
|
LLM response streaming where we need to enforce time-to-first-token limits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class FirstTokenTimeoutError(Exception):
|
||||||
|
"""Raised when a generator fails to yield its first item within the configured timeout."""
|
||||||
|
|
||||||
|
def __init__(self, timeout_ms: int):
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
super().__init__(f"Generator timed out after {timeout_ms}ms without yielding first item")
|
||||||
|
|
||||||
|
|
||||||
|
def with_first_token_timeout(
|
||||||
|
generator: Generator[T, None, None],
|
||||||
|
timeout_seconds: float,
|
||||||
|
) -> Generator[T, None, None]:
|
||||||
|
"""
|
||||||
|
Wrap a generator with first token timeout monitoring.
|
||||||
|
|
||||||
|
Only monitors the time until the FIRST item is yielded.
|
||||||
|
Once the first item arrives, timeout monitoring stops and
|
||||||
|
subsequent items are yielded without timeout checks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: The source generator to wrap
|
||||||
|
timeout_seconds: Maximum time to wait for first item (in seconds)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Items from the source generator
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FirstTokenTimeoutError: If first item doesn't arrive within timeout
|
||||||
|
"""
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
|
# Handle first item separately to check timeout only once
|
||||||
|
try:
|
||||||
|
first_item = next(generator)
|
||||||
|
if time.monotonic() - start_time > timeout_seconds:
|
||||||
|
raise FirstTokenTimeoutError(int(timeout_seconds * 1000))
|
||||||
|
yield first_item
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Yield remaining items without timeout checks
|
||||||
|
yield from generator
|
||||||
@@ -0,0 +1,395 @@
|
|||||||
|
"""Tests for LLM Node first token timeout retry functionality."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||||
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||||
|
from core.workflow.nodes.base.entities import RetryConfig
|
||||||
|
from core.workflow.utils.generator_timeout import FirstTokenTimeoutError, with_first_token_timeout
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryConfigFirstTokenTimeout:
|
||||||
|
"""Test cases for RetryConfig first token timeout fields."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test that first token timeout fields have correct default values."""
|
||||||
|
config = RetryConfig()
|
||||||
|
|
||||||
|
assert config.first_token_timeout == 0
|
||||||
|
assert config.has_first_token_timeout is False
|
||||||
|
|
||||||
|
def test_has_first_token_timeout_when_retry_enabled_and_positive(self):
|
||||||
|
"""Test has_first_token_timeout returns True when retry enabled with positive timeout."""
|
||||||
|
config = RetryConfig(
|
||||||
|
retry_enabled=True,
|
||||||
|
first_token_timeout=3000, # 3000ms = 3s
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.has_first_token_timeout is True
|
||||||
|
assert config.first_token_timeout_seconds == 3.0
|
||||||
|
|
||||||
|
def test_has_first_token_timeout_when_retry_disabled(self):
|
||||||
|
"""Test has_first_token_timeout returns False when retry is disabled."""
|
||||||
|
config = RetryConfig(
|
||||||
|
retry_enabled=False,
|
||||||
|
first_token_timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.has_first_token_timeout is False
|
||||||
|
|
||||||
|
def test_has_first_token_timeout_when_zero_timeout(self):
|
||||||
|
"""Test has_first_token_timeout returns False when timeout is 0."""
|
||||||
|
config = RetryConfig(
|
||||||
|
retry_enabled=True,
|
||||||
|
first_token_timeout=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.has_first_token_timeout is False
|
||||||
|
|
||||||
|
def test_backward_compatibility(self):
|
||||||
|
"""Test that existing workflows without first_token_timeout work correctly."""
|
||||||
|
old_config_data = {
|
||||||
|
"max_retries": 3,
|
||||||
|
"retry_interval": 1000,
|
||||||
|
"retry_enabled": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
config = RetryConfig.model_validate(old_config_data)
|
||||||
|
|
||||||
|
assert config.max_retries == 3
|
||||||
|
assert config.retry_interval == 1000
|
||||||
|
assert config.retry_enabled is True
|
||||||
|
assert config.first_token_timeout == 0
|
||||||
|
# has_first_token_timeout is False because timeout is 0
|
||||||
|
assert config.has_first_token_timeout is False
|
||||||
|
|
||||||
|
def test_full_config_serialization(self):
|
||||||
|
"""Test that full config can be serialized and deserialized."""
|
||||||
|
config = RetryConfig(
|
||||||
|
max_retries=5,
|
||||||
|
retry_interval=2000,
|
||||||
|
retry_enabled=True,
|
||||||
|
first_token_timeout=120,
|
||||||
|
)
|
||||||
|
|
||||||
|
config_dict = config.model_dump()
|
||||||
|
restored_config = RetryConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
assert restored_config.max_retries == 5
|
||||||
|
assert restored_config.retry_interval == 2000
|
||||||
|
assert restored_config.retry_enabled is True
|
||||||
|
assert restored_config.first_token_timeout == 120
|
||||||
|
assert restored_config.has_first_token_timeout is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithFirstTokenTimeout:
|
||||||
|
"""Test cases for with_first_token_timeout function."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_mock_chunk(text: str = "test") -> LLMResultChunk:
|
||||||
|
"""Helper to create a mock LLMResultChunk."""
|
||||||
|
return LLMResultChunk(
|
||||||
|
model="test-model",
|
||||||
|
prompt_messages=[],
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content=text),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_first_token_arrives_within_timeout(self):
|
||||||
|
"""Test that chunks are yielded normally when first token arrives in time."""
|
||||||
|
|
||||||
|
def mock_generator() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
yield self._create_mock_chunk("Hello")
|
||||||
|
yield self._create_mock_chunk(" world")
|
||||||
|
|
||||||
|
wrapped = with_first_token_timeout(mock_generator(), timeout_seconds=10)
|
||||||
|
chunks = list(wrapped)
|
||||||
|
|
||||||
|
assert len(chunks) == 2
|
||||||
|
|
||||||
|
def test_first_token_timeout_raises_error(self, monkeypatch):
|
||||||
|
"""Test that timeout error is raised when first token doesn't arrive in time."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_monotonic():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
# First call: start_time = 0
|
||||||
|
# Second call (when checking): current_time = 11 (exceeds 10 second timeout)
|
||||||
|
if call_count == 1:
|
||||||
|
return 0.0
|
||||||
|
return 11.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||||
|
|
||||||
|
def slow_generator() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
# This chunk arrives "after timeout"
|
||||||
|
yield self._create_mock_chunk("Late token")
|
||||||
|
|
||||||
|
wrapped = with_first_token_timeout(slow_generator(), timeout_seconds=10)
|
||||||
|
|
||||||
|
with pytest.raises(FirstTokenTimeoutError) as exc_info:
|
||||||
|
list(wrapped)
|
||||||
|
|
||||||
|
# Error message shows milliseconds (10 seconds = 10000ms)
|
||||||
|
assert "10000ms" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_no_timeout_check_after_first_token(self, monkeypatch):
|
||||||
|
"""Test that subsequent chunks are not subject to timeout after first token received."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_monotonic():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return 0.0 # start_time
|
||||||
|
elif call_count == 2:
|
||||||
|
return 5.0 # first token arrives at 5s (within 10s timeout)
|
||||||
|
else:
|
||||||
|
# Subsequent calls simulate long delays for remaining chunks
|
||||||
|
# These should NOT trigger timeout because first token already received
|
||||||
|
return 100.0 + call_count
|
||||||
|
|
||||||
|
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||||
|
|
||||||
|
def generator_with_slow_subsequent_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
yield self._create_mock_chunk("First")
|
||||||
|
yield self._create_mock_chunk("Second")
|
||||||
|
yield self._create_mock_chunk("Third")
|
||||||
|
|
||||||
|
wrapped = with_first_token_timeout(
|
||||||
|
generator_with_slow_subsequent_chunks(),
|
||||||
|
timeout_seconds=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise, even though "time" passes beyond timeout after first token
|
||||||
|
chunks = list(wrapped)
|
||||||
|
assert len(chunks) == 3
|
||||||
|
|
||||||
|
def test_empty_generator_no_error(self):
|
||||||
|
"""Test that empty generator doesn't raise timeout error (no chunks to check)."""
|
||||||
|
|
||||||
|
def empty_generator() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
return
|
||||||
|
yield # unreachable, but makes this a generator
|
||||||
|
|
||||||
|
wrapped = with_first_token_timeout(empty_generator(), timeout_seconds=10)
|
||||||
|
chunks = list(wrapped)
|
||||||
|
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_exact_timeout_boundary(self, monkeypatch):
|
||||||
|
"""Test behavior at exact timeout boundary (should not raise when equal)."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_monotonic():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return 0.0
|
||||||
|
# Exactly at boundary: current_time - start_time = 10, timeout_seconds = 10
|
||||||
|
# Since we check > not >=, this should NOT raise
|
||||||
|
return 10.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||||
|
|
||||||
|
def generator() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
yield self._create_mock_chunk("Token at boundary")
|
||||||
|
|
||||||
|
wrapped = with_first_token_timeout(generator(), timeout_seconds=10)
|
||||||
|
|
||||||
|
# Should not raise because 10 is not > 10
|
||||||
|
chunks = list(wrapped)
|
||||||
|
assert len(chunks) == 1
|
||||||
|
|
||||||
|
def test_just_over_timeout_boundary(self, monkeypatch):
|
||||||
|
"""Test behavior just over timeout boundary (should raise)."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_monotonic():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return 0.0
|
||||||
|
# Just over boundary
|
||||||
|
return 10.001
|
||||||
|
|
||||||
|
monkeypatch.setattr(time, "monotonic", mock_monotonic)
|
||||||
|
|
||||||
|
def generator() -> Generator[LLMResultChunk, None, None]:
|
||||||
|
yield self._create_mock_chunk("Late token")
|
||||||
|
|
||||||
|
wrapped = with_first_token_timeout(generator(), timeout_seconds=10)
|
||||||
|
|
||||||
|
with pytest.raises(FirstTokenTimeoutError):
|
||||||
|
list(wrapped)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMNodeInvokeLLMWithTimeout:
|
||||||
|
"""Test cases for LLMNode.invoke_llm with first_token_timeout parameter."""
|
||||||
|
|
||||||
|
def test_invoke_llm_without_timeout(self):
|
||||||
|
"""Test invoke_llm works normally when first_token_timeout is None."""
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
|
||||||
|
with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
|
||||||
|
mock_handle.return_value = iter([])
|
||||||
|
|
||||||
|
# Mock model_instance.invoke_llm to return empty generator
|
||||||
|
mock_model_instance = mock.MagicMock()
|
||||||
|
mock_model_instance.invoke_llm.return_value = iter([])
|
||||||
|
mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()
|
||||||
|
|
||||||
|
mock_node_data_model = mock.MagicMock()
|
||||||
|
mock_node_data_model.completion_params = {}
|
||||||
|
|
||||||
|
result = LLMNode.invoke_llm(
|
||||||
|
node_data_model=mock_node_data_model,
|
||||||
|
model_instance=mock_model_instance,
|
||||||
|
prompt_messages=[],
|
||||||
|
user_id="test-user",
|
||||||
|
structured_output_enabled=False,
|
||||||
|
structured_output=None,
|
||||||
|
file_saver=mock.MagicMock(),
|
||||||
|
file_outputs=[],
|
||||||
|
node_id="test-node",
|
||||||
|
node_type=mock.MagicMock(),
|
||||||
|
first_token_timeout=None, # No timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
list(result) # Consume generator
|
||||||
|
mock_handle.assert_called_once()
|
||||||
|
|
||||||
|
def test_invoke_llm_with_timeout_passes_to_model_instance(self):
|
||||||
|
"""Test invoke_llm passes first_token_timeout to model_instance.invoke_llm."""
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
|
||||||
|
with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
|
||||||
|
mock_handle.return_value = iter([])
|
||||||
|
|
||||||
|
mock_model_instance = mock.MagicMock()
|
||||||
|
mock_model_instance.invoke_llm.return_value = iter([])
|
||||||
|
mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()
|
||||||
|
|
||||||
|
mock_node_data_model = mock.MagicMock()
|
||||||
|
mock_node_data_model.completion_params = {}
|
||||||
|
|
||||||
|
result = LLMNode.invoke_llm(
|
||||||
|
node_data_model=mock_node_data_model,
|
||||||
|
model_instance=mock_model_instance,
|
||||||
|
prompt_messages=[],
|
||||||
|
user_id="test-user",
|
||||||
|
structured_output_enabled=False,
|
||||||
|
structured_output=None,
|
||||||
|
file_saver=mock.MagicMock(),
|
||||||
|
file_outputs=[],
|
||||||
|
node_id="test-node",
|
||||||
|
node_type=mock.MagicMock(),
|
||||||
|
first_token_timeout=60, # With timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
list(result) # Consume generator
|
||||||
|
|
||||||
|
# Verify model_instance.invoke_llm was called with first_token_timeout
|
||||||
|
mock_model_instance.invoke_llm.assert_called_once()
|
||||||
|
call_kwargs = mock_model_instance.invoke_llm.call_args.kwargs
|
||||||
|
assert call_kwargs.get("first_token_timeout") == 60
|
||||||
|
|
||||||
|
def test_invoke_llm_with_zero_timeout_passes_zero(self):
|
||||||
|
"""Test invoke_llm passes zero timeout to model_instance."""
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
|
||||||
|
with mock.patch.object(LLMNode, "handle_invoke_result") as mock_handle:
|
||||||
|
mock_handle.return_value = iter([])
|
||||||
|
|
||||||
|
mock_model_instance = mock.MagicMock()
|
||||||
|
mock_model_instance.invoke_llm.return_value = iter([])
|
||||||
|
mock_model_instance.model_type_instance.get_model_schema.return_value = mock.MagicMock()
|
||||||
|
|
||||||
|
mock_node_data_model = mock.MagicMock()
|
||||||
|
mock_node_data_model.completion_params = {}
|
||||||
|
|
||||||
|
result = LLMNode.invoke_llm(
|
||||||
|
node_data_model=mock_node_data_model,
|
||||||
|
model_instance=mock_model_instance,
|
||||||
|
prompt_messages=[],
|
||||||
|
user_id="test-user",
|
||||||
|
structured_output_enabled=False,
|
||||||
|
structured_output=None,
|
||||||
|
file_saver=mock.MagicMock(),
|
||||||
|
file_outputs=[],
|
||||||
|
node_id="test-node",
|
||||||
|
node_type=mock.MagicMock(),
|
||||||
|
first_token_timeout=0, # Zero timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
list(result) # Consume generator
|
||||||
|
|
||||||
|
# Verify model_instance.invoke_llm was called with zero timeout
|
||||||
|
mock_model_instance.invoke_llm.assert_called_once()
|
||||||
|
call_kwargs = mock_model_instance.invoke_llm.call_args.kwargs
|
||||||
|
assert call_kwargs.get("first_token_timeout") == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryConfigIntegration:
|
||||||
|
"""Integration tests for RetryConfig with LLM node data."""
|
||||||
|
|
||||||
|
def test_retry_config_in_node_data(self):
|
||||||
|
"""Test RetryConfig can be properly configured in LLMNodeData."""
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
|
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
|
||||||
|
|
||||||
|
node_data = LLMNodeData(
|
||||||
|
title="Test LLM",
|
||||||
|
model=ModelConfig(
|
||||||
|
provider="openai",
|
||||||
|
name="gpt-4",
|
||||||
|
mode=LLMMode.CHAT,
|
||||||
|
completion_params={},
|
||||||
|
),
|
||||||
|
prompt_template=[],
|
||||||
|
context=ContextConfig(enabled=False),
|
||||||
|
structured_output_enabled=False,
|
||||||
|
retry_config=RetryConfig(
|
||||||
|
max_retries=3,
|
||||||
|
retry_interval=1000,
|
||||||
|
retry_enabled=True,
|
||||||
|
first_token_timeout=3000, # 3000ms = 3s
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert node_data.retry_config.max_retries == 3
|
||||||
|
assert node_data.retry_config.retry_enabled is True
|
||||||
|
assert node_data.retry_config.first_token_timeout == 3000
|
||||||
|
assert node_data.retry_config.first_token_timeout_seconds == 3.0
|
||||||
|
assert node_data.retry_config.has_first_token_timeout is True
|
||||||
|
|
||||||
|
def test_default_retry_config_in_node_data(self):
|
||||||
|
"""Test default RetryConfig in LLMNodeData."""
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
|
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
|
||||||
|
|
||||||
|
node_data = LLMNodeData(
|
||||||
|
title="Test LLM",
|
||||||
|
model=ModelConfig(
|
||||||
|
provider="openai",
|
||||||
|
name="gpt-4",
|
||||||
|
mode=LLMMode.CHAT,
|
||||||
|
completion_params={},
|
||||||
|
),
|
||||||
|
prompt_template=[],
|
||||||
|
context=ContextConfig(enabled=False),
|
||||||
|
structured_output_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have default RetryConfig
|
||||||
|
assert node_data.retry_config.max_retries == 0
|
||||||
|
assert node_data.retry_config.retry_enabled is False
|
||||||
|
assert node_data.retry_config.first_token_timeout == 0
|
||||||
|
assert node_data.retry_config.has_first_token_timeout is False
|
||||||
@@ -9,6 +9,10 @@ import Split from '@/app/components/workflow/nodes/_base/components/split'
|
|||||||
import { useRetryConfig } from './hooks'
|
import { useRetryConfig } from './hooks'
|
||||||
import s from './style.module.css'
|
import s from './style.module.css'
|
||||||
|
|
||||||
|
// Nodes that support first token timeout configuration
|
||||||
|
// These nodes internally call LLM and have streaming response characteristics
|
||||||
|
const LLM_RELATED_NODE_TYPES = ['llm', 'agent', 'parameter-extractor', 'question-classifier']
|
||||||
|
|
||||||
type RetryOnPanelProps = Pick<Node, 'id' | 'data'>
|
type RetryOnPanelProps = Pick<Node, 'id' | 'data'>
|
||||||
const RetryOnPanel = ({
|
const RetryOnPanel = ({
|
||||||
id,
|
id,
|
||||||
@@ -16,10 +20,14 @@ const RetryOnPanel = ({
|
|||||||
}: RetryOnPanelProps) => {
|
}: RetryOnPanelProps) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { handleRetryConfigChange } = useRetryConfig(id)
|
const { handleRetryConfigChange } = useRetryConfig(id)
|
||||||
const { retry_config } = data
|
const { retry_config, type } = data
|
||||||
|
|
||||||
|
// Check if this is an LLM-related node that supports first token timeout
|
||||||
|
const isLLMRelatedNode = LLM_RELATED_NODE_TYPES.includes(type)
|
||||||
|
|
||||||
const handleRetryEnabledChange = (value: boolean) => {
|
const handleRetryEnabledChange = (value: boolean) => {
|
||||||
handleRetryConfigChange({
|
handleRetryConfigChange({
|
||||||
|
...retry_config,
|
||||||
retry_enabled: value,
|
retry_enabled: value,
|
||||||
max_retries: retry_config?.max_retries || 3,
|
max_retries: retry_config?.max_retries || 3,
|
||||||
retry_interval: retry_config?.retry_interval || 1000,
|
retry_interval: retry_config?.retry_interval || 1000,
|
||||||
@@ -32,6 +40,7 @@ const RetryOnPanel = ({
|
|||||||
else if (value < 1)
|
else if (value < 1)
|
||||||
value = 1
|
value = 1
|
||||||
handleRetryConfigChange({
|
handleRetryConfigChange({
|
||||||
|
...retry_config,
|
||||||
retry_enabled: true,
|
retry_enabled: true,
|
||||||
max_retries: value,
|
max_retries: value,
|
||||||
retry_interval: retry_config?.retry_interval || 1000,
|
retry_interval: retry_config?.retry_interval || 1000,
|
||||||
@@ -44,12 +53,27 @@ const RetryOnPanel = ({
|
|||||||
else if (value < 100)
|
else if (value < 100)
|
||||||
value = 100
|
value = 100
|
||||||
handleRetryConfigChange({
|
handleRetryConfigChange({
|
||||||
|
...retry_config,
|
||||||
retry_enabled: true,
|
retry_enabled: true,
|
||||||
max_retries: retry_config?.max_retries || 3,
|
max_retries: retry_config?.max_retries || 3,
|
||||||
retry_interval: value,
|
retry_interval: value,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleFirstTokenTimeoutChange = (value: number) => {
|
||||||
|
if (value > 60000)
|
||||||
|
value = 60000
|
||||||
|
else if (value < 0)
|
||||||
|
value = 0
|
||||||
|
handleRetryConfigChange({
|
||||||
|
...retry_config,
|
||||||
|
retry_enabled: true,
|
||||||
|
max_retries: retry_config?.max_retries || 3,
|
||||||
|
retry_interval: retry_config?.retry_interval || 1000,
|
||||||
|
first_token_timeout: value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<div className="pt-2">
|
<div className="pt-2">
|
||||||
@@ -62,54 +86,76 @@ const RetryOnPanel = ({
|
|||||||
onChange={v => handleRetryEnabledChange(v)}
|
onChange={v => handleRetryEnabledChange(v)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{
|
{retry_config?.retry_enabled && (
|
||||||
retry_config?.retry_enabled && (
|
<div className="px-4 pb-2">
|
||||||
<div className="px-4 pb-2">
|
<div className="mb-1 flex w-full items-center">
|
||||||
<div className="mb-1 flex w-full items-center">
|
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.maxRetries', { ns: 'workflow' })}</div>
|
||||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.maxRetries', { ns: 'workflow' })}</div>
|
<Slider
|
||||||
|
className="mr-3 w-[108px]"
|
||||||
|
value={retry_config?.max_retries || 3}
|
||||||
|
onChange={handleMaxRetriesChange}
|
||||||
|
min={1}
|
||||||
|
max={10}
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
wrapperClassName="w-[100px]"
|
||||||
|
value={retry_config?.max_retries || 3}
|
||||||
|
onChange={e =>
|
||||||
|
handleMaxRetriesChange(Number.parseInt(e.currentTarget.value, 10) || 3)}
|
||||||
|
min={1}
|
||||||
|
max={10}
|
||||||
|
unit={t('nodes.common.retry.times', { ns: 'workflow' }) || ''}
|
||||||
|
className={s.input}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="mb-1 flex w-full items-center">
|
||||||
|
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.retryInterval', { ns: 'workflow' })}</div>
|
||||||
|
<Slider
|
||||||
|
className="mr-3 w-[108px]"
|
||||||
|
value={retry_config?.retry_interval || 1000}
|
||||||
|
onChange={handleRetryIntervalChange}
|
||||||
|
min={100}
|
||||||
|
max={5000}
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
wrapperClassName="w-[100px]"
|
||||||
|
value={retry_config?.retry_interval || 1000}
|
||||||
|
onChange={e =>
|
||||||
|
handleRetryIntervalChange(Number.parseInt(e.currentTarget.value, 10) || 1000)}
|
||||||
|
min={100}
|
||||||
|
max={5000}
|
||||||
|
unit={t('nodes.common.retry.ms', { ns: 'workflow' }) || ''}
|
||||||
|
className={s.input}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{/* First token timeout - only for LLM-related nodes */}
|
||||||
|
{isLLMRelatedNode && (
|
||||||
|
<div className="flex w-full items-center">
|
||||||
|
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.firstTokenTimeout', { ns: 'workflow' })}</div>
|
||||||
<Slider
|
<Slider
|
||||||
className="mr-3 w-[108px]"
|
className="mr-3 w-[108px]"
|
||||||
value={retry_config?.max_retries || 3}
|
value={retry_config?.first_token_timeout ?? 3000}
|
||||||
onChange={handleMaxRetriesChange}
|
onChange={handleFirstTokenTimeoutChange}
|
||||||
min={1}
|
min={0}
|
||||||
max={10}
|
max={60000}
|
||||||
/>
|
/>
|
||||||
<Input
|
<Input
|
||||||
type="number"
|
type="number"
|
||||||
wrapperClassName="w-[100px]"
|
wrapperClassName="w-[100px]"
|
||||||
value={retry_config?.max_retries || 3}
|
value={retry_config?.first_token_timeout ?? 3000}
|
||||||
onChange={e =>
|
onChange={e =>
|
||||||
handleMaxRetriesChange(Number.parseInt(e.currentTarget.value, 10) || 3)}
|
handleFirstTokenTimeoutChange(Number.parseInt(e.currentTarget.value, 10) || 0)}
|
||||||
min={1}
|
min={0}
|
||||||
max={10}
|
max={60000}
|
||||||
unit={t('nodes.common.retry.times', { ns: 'workflow' }) || ''}
|
|
||||||
className={s.input}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="flex items-center">
|
|
||||||
<div className="system-xs-medium-uppercase mr-2 grow text-text-secondary">{t('nodes.common.retry.retryInterval', { ns: 'workflow' })}</div>
|
|
||||||
<Slider
|
|
||||||
className="mr-3 w-[108px]"
|
|
||||||
value={retry_config?.retry_interval || 1000}
|
|
||||||
onChange={handleRetryIntervalChange}
|
|
||||||
min={100}
|
|
||||||
max={5000}
|
|
||||||
/>
|
|
||||||
<Input
|
|
||||||
type="number"
|
|
||||||
wrapperClassName="w-[100px]"
|
|
||||||
value={retry_config?.retry_interval || 1000}
|
|
||||||
onChange={e =>
|
|
||||||
handleRetryIntervalChange(Number.parseInt(e.currentTarget.value, 10) || 1000)}
|
|
||||||
min={100}
|
|
||||||
max={5000}
|
|
||||||
unit={t('nodes.common.retry.ms', { ns: 'workflow' }) || ''}
|
unit={t('nodes.common.retry.ms', { ns: 'workflow' }) || ''}
|
||||||
className={s.input}
|
className={s.input}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
)}
|
||||||
)
|
</div>
|
||||||
}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<Split className="mx-4 mt-2" />
|
<Split className="mx-4 mt-2" />
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -2,4 +2,6 @@ export type WorkflowRetryConfig = {
|
|||||||
max_retries: number
|
max_retries: number
|
||||||
retry_interval: number
|
retry_interval: number
|
||||||
retry_enabled: boolean
|
retry_enabled: boolean
|
||||||
|
// First token timeout for LLM nodes (seconds), 0 means no timeout
|
||||||
|
first_token_timeout?: number
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -433,6 +433,7 @@
|
|||||||
"nodes.common.memory.windowSize": "Window Size",
|
"nodes.common.memory.windowSize": "Window Size",
|
||||||
"nodes.common.outputVars": "Output Variables",
|
"nodes.common.outputVars": "Output Variables",
|
||||||
"nodes.common.pluginNotInstalled": "Plugin is not installed",
|
"nodes.common.pluginNotInstalled": "Plugin is not installed",
|
||||||
|
"nodes.common.retry.firstTokenTimeout": "First token timeout",
|
||||||
"nodes.common.retry.maxRetries": "max retries",
|
"nodes.common.retry.maxRetries": "max retries",
|
||||||
"nodes.common.retry.ms": "ms",
|
"nodes.common.retry.ms": "ms",
|
||||||
"nodes.common.retry.retries": "{{num}} Retries",
|
"nodes.common.retry.retries": "{{num}} Retries",
|
||||||
@@ -444,6 +445,8 @@
|
|||||||
"nodes.common.retry.retrySuccessful": "Retry successful",
|
"nodes.common.retry.retrySuccessful": "Retry successful",
|
||||||
"nodes.common.retry.retryTimes": "Retry {{times}} times on failure",
|
"nodes.common.retry.retryTimes": "Retry {{times}} times on failure",
|
||||||
"nodes.common.retry.retrying": "Retrying...",
|
"nodes.common.retry.retrying": "Retrying...",
|
||||||
|
"nodes.common.retry.seconds": "s",
|
||||||
|
"nodes.common.retry.timeoutDuration": "Timeout duration",
|
||||||
"nodes.common.retry.times": "times",
|
"nodes.common.retry.times": "times",
|
||||||
"nodes.common.typeSwitch.input": "Input value",
|
"nodes.common.typeSwitch.input": "Input value",
|
||||||
"nodes.common.typeSwitch.variable": "Use variable",
|
"nodes.common.typeSwitch.variable": "Use variable",
|
||||||
|
|||||||
@@ -433,6 +433,7 @@
|
|||||||
"nodes.common.memory.windowSize": "记忆窗口",
|
"nodes.common.memory.windowSize": "记忆窗口",
|
||||||
"nodes.common.outputVars": "输出变量",
|
"nodes.common.outputVars": "输出变量",
|
||||||
"nodes.common.pluginNotInstalled": "插件未安装",
|
"nodes.common.pluginNotInstalled": "插件未安装",
|
||||||
|
"nodes.common.retry.firstTokenTimeout": "首个 Token 超时",
|
||||||
"nodes.common.retry.maxRetries": "最大重试次数",
|
"nodes.common.retry.maxRetries": "最大重试次数",
|
||||||
"nodes.common.retry.ms": "毫秒",
|
"nodes.common.retry.ms": "毫秒",
|
||||||
"nodes.common.retry.retries": "{{num}} 重试次数",
|
"nodes.common.retry.retries": "{{num}} 重试次数",
|
||||||
@@ -444,6 +445,8 @@
|
|||||||
"nodes.common.retry.retrySuccessful": "重试成功",
|
"nodes.common.retry.retrySuccessful": "重试成功",
|
||||||
"nodes.common.retry.retryTimes": "失败时重试 {{times}} 次",
|
"nodes.common.retry.retryTimes": "失败时重试 {{times}} 次",
|
||||||
"nodes.common.retry.retrying": "重试中...",
|
"nodes.common.retry.retrying": "重试中...",
|
||||||
|
"nodes.common.retry.seconds": "秒",
|
||||||
|
"nodes.common.retry.timeoutDuration": "超时时长",
|
||||||
"nodes.common.retry.times": "次",
|
"nodes.common.retry.times": "次",
|
||||||
"nodes.common.typeSwitch.input": "输入值",
|
"nodes.common.typeSwitch.input": "输入值",
|
||||||
"nodes.common.typeSwitch.variable": "使用变量",
|
"nodes.common.typeSwitch.variable": "使用变量",
|
||||||
|
|||||||
Reference in New Issue
Block a user