mirror of
https://github.com/langgenius/dify.git
synced 2026-03-13 11:17:07 +00:00
Compare commits
2 Commits
deploy/dev
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd7c0c1802 | ||
|
|
e9271bf6d1 |
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@@ -17,8 +17,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4, 5, 6]
|
||||
shardTotal: [6]
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -72,7 +72,7 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --reporter=json --reporter=agent --coverage
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
@@ -188,6 +188,7 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
|
||||
@@ -17,6 +17,11 @@ class WeaviateConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENABLED: bool = Field(
|
||||
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
|
||||
@@ -114,7 +114,6 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
|
||||
@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -193,8 +193,7 @@ class LLMGenerator:
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
@@ -280,8 +279,7 @@ class LLMGenerator:
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
error = str(e)
|
||||
error_step = "handle unexpected exception"
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
|
||||
@@ -113,26 +113,17 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential, as CredentialType or str; str values
|
||||
are normalized via CredentialType.of and may raise ValueError for invalid values.
|
||||
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
|
||||
empty list for CredentialType.UNAUTHORIZED or missing schemas.
|
||||
|
||||
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
|
||||
Raises ValueError for invalid credential types.
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
"""
|
||||
if isinstance(credential_type, str):
|
||||
credential_type = CredentialType.of(credential_type)
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
|
||||
@@ -137,7 +137,6 @@ class ToolFileManager:
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
|
||||
@@ -276,4 +276,7 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -4,8 +4,7 @@ class InvokeError(ValueError):
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
if description is not None:
|
||||
self.description = description
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
@@ -282,8 +282,7 @@ class ModelProviderFactory:
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
if model_type:
|
||||
simple_provider_schema.models = all_model_type_models
|
||||
simple_provider_schema.models.extend(all_model_type_models)
|
||||
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
|
||||
@@ -7,11 +7,11 @@ dependencies = [
|
||||
"aliyun-log-python-sdk~=0.9.37",
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.2",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.65",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.5.2",
|
||||
"celery~=5.6.2",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress>=1.17,<1.24",
|
||||
@@ -35,30 +35,30 @@ dependencies = [
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.7.16",
|
||||
"markdown~=3.8.1",
|
||||
"markdown~=3.10.2",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.1", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.28.0",
|
||||
"opentelemetry-instrumentation==0.49b0",
|
||||
"opentelemetry-instrumentation-celery==0.49b0",
|
||||
"opentelemetry-instrumentation-flask==0.49b0",
|
||||
"opentelemetry-instrumentation-httpx==0.49b0",
|
||||
"opentelemetry-instrumentation-redis==0.49b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.40.0",
|
||||
"opentelemetry-instrumentation==0.61b0",
|
||||
"opentelemetry-instrumentation-celery==0.61b0",
|
||||
"opentelemetry-instrumentation-flask==0.61b0",
|
||||
"opentelemetry-instrumentation-httpx==0.61b0",
|
||||
"opentelemetry-instrumentation-redis==0.61b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
|
||||
"opentelemetry-propagator-b3==1.40.0",
|
||||
"opentelemetry-proto==1.28.0",
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
"opentelemetry-util-http==0.49b0",
|
||||
"pandas[excel,output-formatting,performance]~=2.2.2",
|
||||
"opentelemetry-proto==1.40.0",
|
||||
"opentelemetry-sdk==1.40.0",
|
||||
"opentelemetry-semantic-conventions==0.61b0",
|
||||
"opentelemetry-util-http==0.61b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"pycryptodome==3.23.0",
|
||||
@@ -66,22 +66,22 @@ dependencies = [
|
||||
"pydantic-extra-types~=2.11.0",
|
||||
"pydantic-settings~=2.13.1",
|
||||
"pyjwt~=2.11.0",
|
||||
"pypdfium2==5.2.0",
|
||||
"pypdfium2==5.6.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.0.1",
|
||||
"python-dotenv==1.2.2",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.9.0",
|
||||
"sentry-sdk[flask]~=2.28.0",
|
||||
"resend~=2.23.0",
|
||||
"sentry-sdk[flask]~=2.54.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.49.1",
|
||||
"starlette==0.52.1",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
|
||||
"yarl~=1.18.3",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"yarl~=1.23.0",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.8.0",
|
||||
"sseclient-py~=1.9.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"sendgrid~=6.12.3",
|
||||
"flask-restx~=1.3.2",
|
||||
@@ -120,7 +120,7 @@ dev = [
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.13.2",
|
||||
"testcontainers~=4.14.1",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=6.2.0",
|
||||
|
||||
@@ -60,6 +60,7 @@ VECTOR_STORE=weaviate
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
|
||||
@@ -1,313 +0,0 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin endpoints
|
||||
|
||||
Tests endpoint structure (method existence) for all plugin APIs, plus
|
||||
handler-level logic tests for representative non-streaming endpoints.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api.plugin.plugin import (
|
||||
PluginFetchAppInfoApi,
|
||||
PluginInvokeAppApi,
|
||||
PluginInvokeEncryptApi,
|
||||
PluginInvokeLLMApi,
|
||||
PluginInvokeLLMWithStructuredOutputApi,
|
||||
PluginInvokeModerationApi,
|
||||
PluginInvokeParameterExtractorNodeApi,
|
||||
PluginInvokeQuestionClassifierNodeApi,
|
||||
PluginInvokeRerankApi,
|
||||
PluginInvokeSpeech2TextApi,
|
||||
PluginInvokeSummaryApi,
|
||||
PluginInvokeTextEmbeddingApi,
|
||||
PluginInvokeToolApi,
|
||||
PluginInvokeTTSApi,
|
||||
PluginUploadFileRequestApi,
|
||||
)
|
||||
|
||||
|
||||
def _extract_raw_post(cls):
|
||||
"""Extract the raw post() method from a plugin endpoint class.
|
||||
|
||||
Plugin endpoint methods are wrapped by several decorators (get_user_tenant,
|
||||
setup_required, plugin_inner_api_only, plugin_data). These decorators
|
||||
use @wraps where possible. This helper ensures we retrieve the original
|
||||
post(self, user_model, tenant_model, payload) function by unwrapping
|
||||
and, if necessary, walking the closure of the innermost wrapper.
|
||||
"""
|
||||
bottom = inspect.unwrap(cls.post)
|
||||
|
||||
# If unwrap() didn't get us to the raw function (e.g. if a decorator
|
||||
# missed @wraps), try to extract it from the closure if it looks like
|
||||
# a plugin_data or similar wrapper that closes over 'view_func'.
|
||||
if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars:
|
||||
try:
|
||||
idx = bottom.__code__.co_freevars.index("view_func")
|
||||
return bottom.__closure__[idx].cell_contents
|
||||
except (AttributeError, TypeError, IndexError):
|
||||
pass
|
||||
|
||||
return bottom
|
||||
|
||||
|
||||
class TestPluginInvokeLLMApi:
|
||||
"""Test PluginInvokeLLMApi endpoint structure"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeLLMWithStructuredOutputApi:
|
||||
"""Test PluginInvokeLLMWithStructuredOutputApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMWithStructuredOutputApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTextEmbeddingApi:
|
||||
"""Test PluginInvokeTextEmbeddingApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTextEmbeddingApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeRerankApi:
|
||||
"""Test PluginInvokeRerankApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeRerankApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTTSApi:
|
||||
"""Test PluginInvokeTTSApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTTSApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeSpeech2TextApi:
|
||||
"""Test PluginInvokeSpeech2TextApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSpeech2TextApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeModerationApi:
|
||||
"""Test PluginInvokeModerationApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeModerationApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeToolApi:
|
||||
"""Test PluginInvokeToolApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeToolApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeParameterExtractorNodeApi:
|
||||
"""Test PluginInvokeParameterExtractorNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeParameterExtractorNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeQuestionClassifierNodeApi:
|
||||
"""Test PluginInvokeQuestionClassifierNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeQuestionClassifierNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeAppApi:
|
||||
"""Test PluginInvokeAppApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeAppApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeEncryptApi:
|
||||
"""Test PluginInvokeEncryptApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeEncryptApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() delegates to PluginEncrypter and returns model_dump output"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act — extract raw post() bypassing all decorators including plugin_data
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload)
|
||||
assert result["data"] == {"encrypted": "data"}
|
||||
assert result.get("error") == ""
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() catches exceptions and returns error response"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed")
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
assert "encrypt failed" in result["error"]
|
||||
|
||||
|
||||
class TestPluginInvokeSummaryApi:
|
||||
"""Test PluginInvokeSummaryApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSummaryApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginUploadFileRequestApi:
|
||||
"""Test PluginUploadFileRequestApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginUploadFileRequestApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin")
|
||||
def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask):
|
||||
"""Test that post() generates a signed URL and returns it"""
|
||||
# Arrange
|
||||
mock_get_url.return_value = "https://storage.example.com/signed-upload-url"
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-id"
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.filename = "test.pdf"
|
||||
mock_payload.mimetype = "application/pdf"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginUploadFileRequestApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_get_url.assert_called_once_with(
|
||||
filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id"
|
||||
)
|
||||
assert result["data"]["url"] == "https://storage.example.com/signed-upload-url"
|
||||
|
||||
|
||||
class TestPluginFetchAppInfoApi:
|
||||
"""Test PluginFetchAppInfoApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginFetchAppInfoApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation")
|
||||
def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask):
|
||||
"""Test that post() fetches app info and returns it"""
|
||||
# Arrange
|
||||
mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.app_id = "app-123"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginFetchAppInfoApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id")
|
||||
assert result["data"] == {"app_name": "My App", "mode": "chat"}
|
||||
@@ -1,305 +0,0 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.plugin.wraps import (
|
||||
TenantUserPayload,
|
||||
get_user,
|
||||
get_user_tenant,
|
||||
plugin_data,
|
||||
)
|
||||
|
||||
|
||||
class TestTenantUserPayload:
|
||||
"""Test TenantUserPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload passes validation"""
|
||||
data = {"tenant_id": "tenant123", "user_id": "user456"}
|
||||
payload = TenantUserPayload.model_validate(data)
|
||||
assert payload.tenant_id == "tenant123"
|
||||
assert payload.user_id == "user456"
|
||||
|
||||
def test_missing_tenant_id(self):
|
||||
"""Test missing tenant_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"user_id": "user456"})
|
||||
|
||||
def test_missing_user_id(self):
|
||||
"""Test missing user_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"tenant_id": "tenant123"})
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test get_user function"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test returning existing user when found by ID"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.query.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_anonymous_user_by_session_id(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test returning existing anonymous user by session_id"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "anonymous_session")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test creating new user when not found in database"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_new_user
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_use_default_session_id_when_user_id_none(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test using default session ID when user_id is None"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", None)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_raise_error_on_database_exception(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test raising ValueError when database operation fails"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
with pytest.raises(ValueError, match="user not found"):
|
||||
get_user("tenant123", "user123")
|
||||
|
||||
|
||||
class TestGetUserTenant:
|
||||
"""Test get_user_tenant decorator"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that decorator injects tenant_model and user_model into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user456"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
|
||||
def test_should_raise_error_when_tenant_id_missing(self, app: Flask):
|
||||
"""Test that Pydantic ValidationError is raised when tenant_id is missing from payload"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert - Pydantic validates payload before manual check
|
||||
with app.test_request_context(json={"user_id": "user456"}):
|
||||
with pytest.raises(ValidationError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_when_tenant_not_found(self, app: Flask):
|
||||
"""Test that ValueError is raised when tenant is not found"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="tenant not found"):
|
||||
protected_view()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that default session ID is used when user_id is empty string"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
|
||||
# Act - use empty string for user_id to trigger default logic
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
from models.model import DefaultEndUserSessionID
|
||||
|
||||
mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID)
|
||||
|
||||
|
||||
class PluginTestPayload:
|
||||
"""Simple test payload class"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.value = data.get("value")
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
return cls(data)
|
||||
|
||||
|
||||
class TestPluginData:
|
||||
"""Test plugin_data decorator"""
|
||||
|
||||
def test_should_inject_valid_payload(self, app: Flask):
|
||||
"""Test that valid payload is injected into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "test_data"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "test_data"
|
||||
|
||||
def test_should_raise_error_on_invalid_json(self, app: Flask):
|
||||
"""Test that ValueError is raised when JSON parsing fails"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert - Malformed JSON triggers ValueError
|
||||
with app.test_request_context(data="not valid json", content_type="application/json"):
|
||||
with pytest.raises(ValueError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_on_invalid_payload(self, app: Flask):
|
||||
"""Test that ValueError is raised when payload validation fails"""
|
||||
|
||||
# Arrange
|
||||
class InvalidPayload:
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
raise Exception("Validation failed")
|
||||
|
||||
@plugin_data(payload_type=InvalidPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"data": "test"}):
|
||||
with pytest.raises(ValueError, match="invalid payload"):
|
||||
protected_view()
|
||||
|
||||
def test_should_work_as_parameterized_decorator(self, app: Flask):
|
||||
"""Test that decorator works when used with parentheses"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "parameterized"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "parameterized"
|
||||
@@ -1,309 +0,0 @@
|
||||
"""
|
||||
Unit tests for inner_api auth decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.inner_api.wraps import (
|
||||
billing_inner_api_only,
|
||||
enterprise_inner_api_only,
|
||||
enterprise_inner_api_user_auth,
|
||||
plugin_inner_api_only,
|
||||
)
|
||||
|
||||
|
||||
class TestBillingInnerApiOnly:
|
||||
"""Test billing_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiOnly:
|
||||
"""Test enterprise_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiUserAuth:
|
||||
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
|
||||
|
||||
def test_should_pass_through_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that request passes through when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_header_missing(self, app: Flask):
|
||||
"""Test that request passes through when Authorization header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_format_invalid(self, app: Flask):
|
||||
"""Test that request passes through when Authorization format is invalid (no colon)"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"Authorization": "invalid_format"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask):
|
||||
"""Test that request passes through when HMAC signature is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act - use wrong signature
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_inject_user_when_hmac_signature_valid(self, app: Flask):
|
||||
"""Test that user is injected when HMAC signature is valid"""
|
||||
# Arrange
|
||||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user")
|
||||
|
||||
# Calculate valid HMAC signature
|
||||
user_id = "user123"
|
||||
inner_api_key = "valid_key"
|
||||
data_to_sign = f"DIFY {user_id}"
|
||||
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
|
||||
valid_signature = b64encode(signature.digest()).decode("utf-8")
|
||||
|
||||
# Create mock user
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
class TestPluginInnerApiOnly:
|
||||
"""Test plugin_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when PLUGIN_DAEMON_KEY is set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask):
|
||||
"""Test that 404 is returned when PLUGIN_DAEMON_KEY is not set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_404_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
@@ -1,206 +0,0 @@
|
||||
"""
|
||||
Unit tests for inner_api mail module
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.mail import (
|
||||
BaseMail,
|
||||
BillingMail,
|
||||
EnterpriseMail,
|
||||
InnerMailPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestInnerMailPayload:
|
||||
"""Test InnerMailPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload_with_all_fields(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
"substitutions": {"key": "value"},
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions == {"key": "value"}
|
||||
|
||||
def test_valid_payload_without_substitutions(self):
|
||||
"""Test valid payload without optional substitutions"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions is None
|
||||
|
||||
def test_empty_to_list_fails_validation(self):
|
||||
"""Test that empty 'to' list fails validation due to min_length=1"""
|
||||
data = {
|
||||
"to": [],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_multiple_recipients_allowed(self):
|
||||
"""Test that multiple recipients are allowed"""
|
||||
data = {
|
||||
"to": ["user1@example.com", "user2@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert len(payload.to) == 2
|
||||
assert "user1@example.com" in payload.to
|
||||
assert "user2@example.com" in payload.to
|
||||
|
||||
def test_missing_to_field_fails_validation(self):
|
||||
"""Test that missing 'to' field fails validation"""
|
||||
data = {
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_subject_fails_validation(self):
|
||||
"""Test that missing 'subject' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_body_fails_validation(self):
|
||||
"""Test that missing 'body' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
|
||||
class TestBaseMail:
|
||||
"""Test BaseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BaseMail API instance"""
|
||||
return BaseMail()
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_sends_email_task(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends inner email task"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
json={
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
):
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Test Subject",
|
||||
body="Test Body",
|
||||
substitutions=None,
|
||||
)
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_with_substitutions(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends email with substitutions"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Hello {{name}}",
|
||||
"body": "Welcome {{name}}!",
|
||||
"substitutions": {"name": "John"},
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Hello {{name}}",
|
||||
body="Welcome {{name}}!",
|
||||
substitutions={"name": "John"},
|
||||
)
|
||||
|
||||
|
||||
class TestEnterpriseMail:
|
||||
"""Test EnterpriseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create EnterpriseMail API instance"""
|
||||
return EnterpriseMail()
|
||||
|
||||
def test_has_enterprise_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has enterprise_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
|
||||
assert enterprise_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
|
||||
|
||||
class TestBillingMail:
|
||||
"""Test BillingMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BillingMail API instance"""
|
||||
return BillingMail()
|
||||
|
||||
def test_has_billing_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that BillingMail has billing_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import billing_inner_api_only
|
||||
|
||||
assert billing_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that BillingMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
@@ -1,184 +0,0 @@
|
||||
"""
|
||||
Unit tests for inner_api workspace module
|
||||
|
||||
Tests Pydantic model validation and endpoint handler logic.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them and focus on business logic.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.workspace.workspace import (
|
||||
EnterpriseWorkspace,
|
||||
EnterpriseWorkspaceNoOwnerEmail,
|
||||
WorkspaceCreatePayload,
|
||||
WorkspaceOwnerlessPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkspaceCreatePayload:
|
||||
"""Test WorkspaceCreatePayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"name": "My Workspace",
|
||||
"owner_email": "owner@example.com",
|
||||
}
|
||||
payload = WorkspaceCreatePayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
assert payload.owner_email == "owner@example.com"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {"owner_email": "owner@example.com"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
def test_missing_owner_email_fails_validation(self):
|
||||
"""Test that missing owner_email fails validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "owner_email" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestWorkspaceOwnerlessPayload:
|
||||
"""Test WorkspaceOwnerlessPayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with name passes validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
payload = WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspace:
|
||||
"""Test EnterpriseWorkspace API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspace()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that EnterpriseWorkspace has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace and assigns the owner account"""
|
||||
# Arrange
|
||||
mock_account = MagicMock()
|
||||
mock_account.email = "owner@example.com"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["name"] == "My Workspace"
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
|
||||
"""Test that post() returns 404 when the owner account does not exist"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "owner account not found."}, 404)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspaceNoOwnerEmail:
|
||||
"""Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspaceNoOwnerEmail()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace without an owner and returns expected fields"""
|
||||
# Arrange
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.encrypt_public_key = "pub-key"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.custom_config = None
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["encrypt_public_key"] == "pub-key"
|
||||
assert result["tenant"]["custom_config"] == {}
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
@@ -1,80 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
class DummyTool:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class DummyPromptEntity:
|
||||
def __init__(self, first_prompt):
|
||||
self.first_prompt = first_prompt
|
||||
|
||||
|
||||
class DummyAgentConfig:
|
||||
def __init__(self, prompt_entity=None):
|
||||
self.prompt = prompt_entity
|
||||
|
||||
|
||||
class DummyAppConfig:
|
||||
def __init__(self, agent=None):
|
||||
self.agent = agent
|
||||
|
||||
|
||||
class DummyScratchpadUnit:
|
||||
def __init__(
|
||||
self,
|
||||
final=False,
|
||||
thought=None,
|
||||
action_str=None,
|
||||
observation=None,
|
||||
agent_response=None,
|
||||
):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_tool_factory():
|
||||
def _factory(name):
|
||||
return DummyTool(name)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_prompt_entity_factory():
|
||||
def _factory(first_prompt):
|
||||
return DummyPromptEntity(first_prompt)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_agent_config_factory():
|
||||
def _factory(prompt_entity=None):
|
||||
return DummyAgentConfig(prompt_entity)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_app_config_factory():
|
||||
def _factory(agent=None):
|
||||
return DummyAppConfig(agent)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_scratchpad_unit_factory():
|
||||
def _factory(**kwargs):
|
||||
return DummyScratchpadUnit(**kwargs)
|
||||
|
||||
return _factory
|
||||
@@ -1,255 +1,70 @@
|
||||
"""Unit tests for CotAgentOutputParser.
|
||||
|
||||
Verifies expected parsing behavior for streaming content and JSON payloads,
|
||||
including edge cases such as empty/non-string content and malformed JSON.
|
||||
Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real
|
||||
model output structures. Implementation under test:
|
||||
core.agent.output_parser.cot_output_parser.CotAgentOutputParser.
|
||||
"""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_action_class(mocker):
|
||||
mock_action = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
|
||||
mock_action,
|
||||
)
|
||||
return mock_action
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usage_dict():
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_chunk():
|
||||
def _make_chunk(content=None, usage=None):
|
||||
delta = SimpleNamespace(
|
||||
message=SimpleNamespace(content=content),
|
||||
usage=usage,
|
||||
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
|
||||
for i in range(len(text)):
|
||||
yield LLMResultChunk(
|
||||
model="model",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
|
||||
)
|
||||
return SimpleNamespace(delta=delta)
|
||||
|
||||
return _make_chunk
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test Suite
|
||||
# ============================================================
|
||||
def test_cot_output_parser():
|
||||
test_cases = [
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with json
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with JSON
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# list
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block
|
||||
{
|
||||
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block and json
|
||||
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
|
||||
]
|
||||
|
||||
|
||||
class TestCotAgentOutputParser:
|
||||
"""Validate CotAgentOutputParser streaming + JSON parsing behavior.
|
||||
|
||||
Lifecycle: no explicit setup/teardown; relies on pytest fixtures for
|
||||
lightweight chunk/action doubles. Invariants: non-string/empty content
|
||||
yields no output, usage gets recorded when provided, and valid action JSON
|
||||
results in Action instantiation. Usage: invoke via pytest (e.g.,
|
||||
`pytest -k TestCotAgentOutputParser`).
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Basic streaming & usage
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_stream_plain_text(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("hello world")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "".join(result) == "hello world"
|
||||
|
||||
def test_stream_empty_string(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_stream_none_content(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk(None)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.parametrize("content", [123, 12.5, [], {}, object()])
|
||||
def test_non_string_content(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_usage_update(self, make_chunk, usage_dict) -> None:
|
||||
usage_data = {"tokens": 99}
|
||||
chunks = [make_chunk("abc", usage=usage_data)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert usage_dict["usage"] == usage_data
|
||||
|
||||
# --------------------------------------------------------
|
||||
# JSON parsing (direct + streaming)
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '{"action": "search", "input": "query"}'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="search", action_input="query")
|
||||
|
||||
def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '[{"action": "lookup", "input": "abc"}]'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None:
|
||||
content = '{"foo": "bar"}'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# Expect the serialized JSON to be yielded as a single element.
|
||||
assert result == [json.dumps({"foo": "bar"})]
|
||||
|
||||
def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None:
|
||||
content = "{invalid json}"
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert any("invalid json" in str(r) for r in result)
|
||||
|
||||
def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
chunks = [
|
||||
make_chunk('{"action": '),
|
||||
make_chunk('"multi", '),
|
||||
make_chunk('"input": "step"}'),
|
||||
]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="multi", action_input="step")
|
||||
|
||||
def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('{"foo": "bar"')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('{"foo": "bar"' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Code block JSON extraction
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = """```json
|
||||
{"action": "lookup", "input": "abc"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
# Multiple JSON objects inside single code fence (invalid combined JSON)
|
||||
# Parser should safely ignore invalid combined block
|
||||
content = """```json
|
||||
{"action": "a1", "input": "x"}
|
||||
{"action": "a2", "input": "y"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# No valid parsed action expected due to invalid combined JSON
|
||||
assert mock_action_class.call_count == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None:
|
||||
content = """```json
|
||||
{invalid}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result
|
||||
|
||||
def test_unclosed_code_block(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('```json {"a":1}')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('```json {"a":1}' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Action / Thought prefix handling
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
" action: something",
|
||||
" ACTION: something",
|
||||
" thought: reasoning",
|
||||
" THOUGHT: reasoning",
|
||||
],
|
||||
)
|
||||
def test_prefix_handling(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
joined = "".join(str(item) for item in result)
|
||||
expected_word = "something" if "action:" in content.lower() else "reasoning"
|
||||
assert expected_word in joined
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("xaction: test")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "x" in "".join(map(str, result))
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Mixed streaming scenarios
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = 'start {"action": "mix", "input": "1"} end'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# JSON action should be parsed
|
||||
mock_action_class.assert_called_once()
|
||||
# Ensure surrounding text is streamed (character-level)
|
||||
joined = "".join(str(r) for r in result if not isinstance(r, MagicMock))
|
||||
assert "start" in joined
|
||||
assert "end" in joined
|
||||
|
||||
def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert mock_action_class.call_count == 2
|
||||
|
||||
def test_backtick_noise(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("text with ` random ` backticks")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "text with" in "".join(result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Boundary & edge inputs
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
"```",
|
||||
"{",
|
||||
"}",
|
||||
"```json",
|
||||
"action:",
|
||||
"thought:",
|
||||
" ",
|
||||
],
|
||||
)
|
||||
def test_edge_inputs(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
joined = "".join(result)
|
||||
if content == " ":
|
||||
assert result == [] or joined == content
|
||||
if content in {"```", "{", "}", "```json"}:
|
||||
assert content in joined
|
||||
if content.lower() in {"action:", "thought:"}:
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
parser = CotAgentOutputParser()
|
||||
usage_dict = {}
|
||||
for test_case in test_cases:
|
||||
# mock llm_response as a generator by text
|
||||
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
|
||||
results = parser.handle_react_stream_output(llm_response, usage_dict)
|
||||
output = ""
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
output += result
|
||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||
if test_case["action"]:
|
||||
assert result.to_dict() == test_case["action"]
|
||||
output += json.dumps(result.to_dict())
|
||||
if test_case["output"]:
|
||||
assert output == test_case["output"]
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
|
||||
|
||||
class DummyStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Concrete implementation for testing BaseAgentStrategy
|
||||
"""
|
||||
|
||||
def __init__(self, return_values=None, raise_exception=None):
|
||||
self.return_values = return_values or []
|
||||
self.raise_exception = raise_exception
|
||||
self.received_args = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id=None,
|
||||
app_id=None,
|
||||
message_id=None,
|
||||
credentials=None,
|
||||
) -> Generator:
|
||||
self.received_args = (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
credentials,
|
||||
)
|
||||
|
||||
if self.raise_exception:
|
||||
raise self.raise_exception
|
||||
|
||||
yield from self.return_values
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInstantiation:
|
||||
def test_cannot_instantiate_abstract_class(self) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
BaseAgentStrategy()
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInvoke:
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
return MagicMock(name="AgentInvokeMessage")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials(self):
|
||||
return MagicMock(name="InvokeCredentials")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("params", "user_id", "conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
({"key": "value"}, "user1", "conv1", "app1", "msg1"),
|
||||
({}, "user2", None, None, None),
|
||||
({"a": 1}, "", "", "", ""),
|
||||
({"nested": {"x": 1}}, "user3", None, "app3", None),
|
||||
],
|
||||
)
|
||||
def test_invoke_success(
|
||||
self,
|
||||
mock_message,
|
||||
mock_credentials,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[mock_message])
|
||||
|
||||
# Act
|
||||
result = list(
|
||||
strategy.invoke(
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
credentials=mock_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == [mock_message]
|
||||
assert strategy.received_args == (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
mock_credentials,
|
||||
)
|
||||
|
||||
def test_invoke_multiple_yields(self, mock_message) -> None:
|
||||
# Arrange
|
||||
messages = [mock_message, MagicMock(), MagicMock()]
|
||||
strategy = DummyStrategy(return_values=messages)
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == messages
|
||||
|
||||
def test_invoke_empty_generator(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_invoke_propagates_exception(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(raise_exception=ValueError("failure"))
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="failure"):
|
||||
list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_params",
|
||||
[
|
||||
None,
|
||||
"",
|
||||
123,
|
||||
[],
|
||||
],
|
||||
)
|
||||
def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None:
|
||||
"""
|
||||
Base class does not validate types — ensure pass-through behavior
|
||||
"""
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params=invalid_params, user_id="user"))
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_invoke_none_user_id(self) -> None:
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params={}, user_id=None))
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestBaseAgentStrategyGetParameters:
|
||||
def test_get_parameters_default_empty_list(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
result = strategy.get_parameters()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result == []
|
||||
|
||||
def test_get_parameters_returns_new_list_each_time(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
|
||||
first = strategy.get_parameters()
|
||||
second = strategy.get_parameters()
|
||||
|
||||
assert first == second == []
|
||||
assert first is not second
|
||||
@@ -1,272 +0,0 @@
|
||||
# File: tests/unit_tests/core/agent/strategy/test_plugin.py
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
|
||||
# ============================================================
|
||||
# Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parameter():
|
||||
def _factory(name="param", return_value="initialized"):
|
||||
param = MagicMock()
|
||||
param.name = name
|
||||
param.init_frontend_parameter = MagicMock(return_value=return_value)
|
||||
return param
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_declaration(mock_parameter):
|
||||
param1 = mock_parameter("param1", "init1")
|
||||
param2 = mock_parameter("param2", "init2")
|
||||
|
||||
identity = MagicMock()
|
||||
identity.provider = "provider_x"
|
||||
identity.name = "strategy_x"
|
||||
|
||||
declaration = MagicMock()
|
||||
declaration.parameters = [param1, param2]
|
||||
declaration.identity = identity
|
||||
|
||||
return declaration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(mock_declaration):
|
||||
return PluginAgentStrategy(
|
||||
tenant_id="tenant_123",
|
||||
declaration=mock_declaration,
|
||||
meta_version="v1",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Initialization Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestPluginAgentStrategyInitialization:
|
||||
def test_init_sets_attributes(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version="meta_v",
|
||||
)
|
||||
|
||||
assert strategy.tenant_id == "tenant_test"
|
||||
assert strategy.declaration == mock_declaration
|
||||
assert strategy.meta_version == "meta_v"
|
||||
|
||||
def test_init_meta_version_none(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version=None,
|
||||
)
|
||||
|
||||
assert strategy.meta_version is None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetParameters:
|
||||
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
|
||||
result = strategy.get_parameters()
|
||||
assert result == mock_declaration.parameters
|
||||
|
||||
|
||||
# ============================================================
|
||||
# initialize_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInitializeParameters:
|
||||
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
|
||||
params = {"param1": "value1"}
|
||||
|
||||
result = strategy.initialize_parameters(params.copy())
|
||||
|
||||
assert result["param1"] == "init1"
|
||||
assert result["param2"] == "init2"
|
||||
|
||||
mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1")
|
||||
mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_params",
|
||||
[
|
||||
{},
|
||||
{"param1": None},
|
||||
{"param1": ""},
|
||||
{"param1": 0},
|
||||
{"param1": []},
|
||||
{"param1": {}, "param2": "value"},
|
||||
],
|
||||
)
|
||||
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
|
||||
result = strategy.initialize_parameters(input_params.copy())
|
||||
|
||||
for param in strategy.declaration.parameters:
|
||||
assert param.name in result
|
||||
|
||||
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.initialize_parameters(None)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# _invoke Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInvoke:
|
||||
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mock_convert = mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={"converted": True},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={"param1": "value"},
|
||||
user_id="user_1",
|
||||
conversation_id="conv_1",
|
||||
app_id="app_1",
|
||||
message_id="msg_1",
|
||||
credentials=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == ["msg1", "msg2"]
|
||||
mock_convert.assert_called_once()
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
call_kwargs = mock_manager.invoke.call_args.kwargs
|
||||
assert call_kwargs["tenant_id"] == "tenant_123"
|
||||
assert call_kwargs["user_id"] == "user_1"
|
||||
assert call_kwargs["agent_provider"] == "provider_x"
|
||||
assert call_kwargs["agent_strategy"] == "strategy_x"
|
||||
assert call_kwargs["agent_params"] == {"converted": True}
|
||||
assert call_kwargs["conversation_id"] == "conv_1"
|
||||
assert call_kwargs["app_id"] == "app_1"
|
||||
assert call_kwargs["message_id"] == "msg_1"
|
||||
assert call_kwargs["context"] is not None
|
||||
|
||||
def test_invoke_with_credentials(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
# Patch PluginInvokeContext to bypass pydantic validation
|
||||
mock_context = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginInvokeContext",
|
||||
return_value=mock_context,
|
||||
)
|
||||
|
||||
credentials = MagicMock()
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
(None, None, None),
|
||||
("conv", None, None),
|
||||
(None, "app", None),
|
||||
(None, None, "msg"),
|
||||
],
|
||||
)
|
||||
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
side_effect=ValueError("conversion failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
|
||||
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
@@ -1,802 +0,0 @@
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.agent.base_agent_runner as module
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
|
||||
# ==========================================================
|
||||
# Fixtures
|
||||
# ==========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(mocker):
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, mock_db_session):
|
||||
r = BaseAgentRunner.__new__(BaseAgentRunner)
|
||||
r.tenant_id = "tenant"
|
||||
r.user_id = "user"
|
||||
r.agent_thought_count = 0
|
||||
r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1")
|
||||
r.app_config = mocker.MagicMock()
|
||||
r.app_config.app_id = "app1"
|
||||
r.app_config.agent = None
|
||||
r.dataset_tools = []
|
||||
r.application_generate_entity = mocker.MagicMock(invoke_from="test")
|
||||
r._current_thoughts = []
|
||||
return r
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _repack_app_generate_entity
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestRepack:
|
||||
def test_sets_empty_if_none(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = None
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == ""
|
||||
|
||||
def test_keeps_existing(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = "abc"
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == "abc"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# update_prompt_message_tool
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestUpdatePromptTool:
|
||||
def build_param(self, mocker, **kwargs):
|
||||
p = mocker.MagicMock()
|
||||
p.form = kwargs.get("form")
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
p.type = mock_type
|
||||
|
||||
p.name = kwargs.get("name", "p1")
|
||||
p.llm_description = "desc"
|
||||
p.input_schema = kwargs.get("input_schema")
|
||||
p.options = kwargs.get("options")
|
||||
p.required = kwargs.get("required", False)
|
||||
return p
|
||||
|
||||
def test_skip_non_llm(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form="NOT_LLM")
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_enum_and_required(self, runner, mocker):
|
||||
option = mocker.MagicMock(value="opt1")
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
options=[option],
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool = mocker.MagicMock()
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert "p1" in result.parameters["required"]
|
||||
|
||||
def test_skip_file_type_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
|
||||
param.type = module.ToolParameter.ToolParameterType.FILE
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_duplicate_required_not_duplicated(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["required"].count("p1") == 1
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# create_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestCreateAgentThought:
|
||||
def test_with_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=10)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"])
|
||||
assert result == "10"
|
||||
assert runner.agent_thought_count == 1
|
||||
|
||||
def test_without_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=11)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", [])
|
||||
assert result == "11"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# save_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestSaveAgentThought:
|
||||
def setup_agent(self, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;tool2"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
return agent
|
||||
|
||||
def test_not_found(self, runner, mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError):
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
def test_full_update(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mock_label = mocker.MagicMock()
|
||||
mock_label.to_dict.return_value = {"en_US": "label"}
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label)
|
||||
|
||||
usage = mocker.MagicMock(
|
||||
prompt_tokens=1,
|
||||
prompt_price_unit=Decimal("0.1"),
|
||||
prompt_unit_price=Decimal("0.1"),
|
||||
completion_tokens=2,
|
||||
completion_price_unit=Decimal("0.2"),
|
||||
completion_unit_price=Decimal("0.2"),
|
||||
total_tokens=3,
|
||||
total_price=Decimal("0.3"),
|
||||
)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1;tool2",
|
||||
{"a": 1},
|
||||
"thought",
|
||||
{"b": 2},
|
||||
{"meta": 1},
|
||||
"answer",
|
||||
["f1"],
|
||||
usage,
|
||||
)
|
||||
|
||||
assert agent.answer == "answer"
|
||||
assert agent.tokens == 3
|
||||
assert "tool1" in json.loads(agent.tool_labels_str)
|
||||
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
agent.tool = "unknown_tool"
|
||||
mock_db_session.scalar.return_value = agent
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "unknown_tool" in labels
|
||||
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
bad_obj = MagicMock()
|
||||
bad_obj.__str__.return_value = "bad"
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
bad_obj,
|
||||
None,
|
||||
bad_obj,
|
||||
bad_obj,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
{"a": 1},
|
||||
None,
|
||||
{"b": 2},
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_user_prompt
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeUserPrompt:
|
||||
def test_no_files(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[])
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_history
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeHistory:
|
||||
def test_empty(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
|
||||
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input="invalid",
|
||||
observation="invalid",
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=";", thought="thinking")
|
||||
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=json.dumps({"tool1": {"x": 1}}),
|
||||
observation=json.dumps({"tool1": "obs"}),
|
||||
thought="thinking",
|
||||
)
|
||||
|
||||
msg = mocker.MagicMock(
|
||||
id="m100",
|
||||
agent_thoughts=[thought],
|
||||
answer=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _convert_tool_to_prompt_message_tool (new coverage)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestConvertToolToPromptMessageTool:
|
||||
def test_basic_conversion(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
runtime_param = mocker.MagicMock()
|
||||
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param.name = "param1"
|
||||
runtime_param.llm_description = "desc"
|
||||
runtime_param.required = True
|
||||
runtime_param.input_schema = None
|
||||
runtime_param.options = None
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
runtime_param.type = mock_type
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert entity == tool_entity
|
||||
|
||||
def test_full_conversion_multiple_params(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
# LLM param with input_schema override
|
||||
param1 = mocker.MagicMock()
|
||||
param1.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param1.name = "p1"
|
||||
param1.llm_description = "desc"
|
||||
param1.required = True
|
||||
param1.input_schema = {"type": "integer"}
|
||||
param1.options = None
|
||||
param1.type = mocker.MagicMock()
|
||||
|
||||
# SYSTEM_FILES param should be skipped
|
||||
param2 = mocker.MagicMock()
|
||||
param2.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param2.name = "file_param"
|
||||
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert entity == tool_entity
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _init_prompt_tools additional branches
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestInitPromptToolsExtended:
|
||||
def test_agent_tool_branch(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="agent_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert "agent_tool" in tools
|
||||
|
||||
def test_exception_in_conversion(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="bad_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert tools == {}
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestAdditionalCoverage:
|
||||
def test_update_prompt_with_input_schema(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "p1"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.options = None
|
||||
param.input_schema = {"type": "number"}
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
param.type = mock_type
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"]["p1"]["type"] == "number"
|
||||
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {"tool1": {"en_US": "existing"}}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert labels["tool1"]["en_US"] == "existing"
|
||||
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
|
||||
assert agent.tool_meta_str == "meta_string"
|
||||
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
assert prompt is not None
|
||||
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"])
|
||||
mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock())
|
||||
|
||||
mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw))
|
||||
mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result is not None
|
||||
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=None, thought="thinking")
|
||||
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1;tool2",
|
||||
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
|
||||
observation=json.dumps({"tool1": "o1", "tool2": "o2"}),
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
# ================= Additional Surgical Coverage =================
|
||||
|
||||
def test_convert_tool_select_enum_branch(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = True
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert prompt_tool is not None
|
||||
|
||||
|
||||
class TestConvertDatasetRetrieverTool:
|
||||
def test_required_param_added(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
|
||||
assert prompt is not None
|
||||
|
||||
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"])
|
||||
|
||||
llm = mocker.MagicMock()
|
||||
llm.get_model_schema.return_value = mocker.MagicMock(
|
||||
features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION]
|
||||
)
|
||||
model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
|
||||
|
||||
app_config = mocker.MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.agent = None
|
||||
app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"})
|
||||
app_config.additional_features = mocker.MagicMock(show_retrieve_source=True)
|
||||
|
||||
app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
|
||||
message = mocker.MagicMock(id="msg1", conversation_id="conv1")
|
||||
|
||||
runner = BaseAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=app_generate,
|
||||
conversation=mocker.MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=mocker.MagicMock(),
|
||||
config=mocker.MagicMock(),
|
||||
queue_manager=mocker.MagicMock(),
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
assert runner.stream_tool_call is True
|
||||
assert runner.files == ["file1"]
|
||||
assert runner.dataset_tools == ["ds_tool"]
|
||||
assert runner.agent_thought_count == 2
|
||||
|
||||
|
||||
class TestBaseAgentRunnerCoverage:
|
||||
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = "NOT_LLM"
|
||||
param.type = mocker.MagicMock()
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert prompt_tool.parameters["properties"] == {}
|
||||
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
|
||||
dataset_tool = mocker.MagicMock()
|
||||
dataset_tool.entity.identity.name = "ds"
|
||||
runner.dataset_tools = [dataset_tool]
|
||||
|
||||
mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock())
|
||||
|
||||
tools, prompt_tools = runner._init_prompt_tools()
|
||||
|
||||
assert tools["ds"] == dataset_tool
|
||||
assert len(prompt_tools) == 1
|
||||
|
||||
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
|
||||
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
tool_input = {"a": 1}
|
||||
observation = {"b": 2}
|
||||
tool_meta = {"c": 3}
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def dumps_side_effect(value, *args, **kwargs):
|
||||
if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False:
|
||||
raise TypeError("fail")
|
||||
return real_dumps(value, *args, **kwargs)
|
||||
|
||||
mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1",
|
||||
tool_input,
|
||||
None,
|
||||
observation,
|
||||
tool_meta,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
assert isinstance(agent.tool_meta_str, str)
|
||||
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;;"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "" not in labels
|
||||
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
|
||||
system_message = module.SystemPromptMessage(content="sys")
|
||||
|
||||
result = runner.organize_agent_history([system_message])
|
||||
|
||||
assert system_message in result
|
||||
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=None,
|
||||
observation=None,
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"organize_agent_user_prompt",
|
||||
return_value=module.UserPromptMessage(content="user"),
|
||||
)
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
|
||||
assert any(isinstance(item, module.ToolPromptMessage) for item in result)
|
||||
@@ -1,551 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
|
||||
|
||||
class DummyRunner(CotAgentRunner):
|
||||
"""Concrete implementation for testing abstract methods."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
# Minimal required defaults
|
||||
self.history_prompt_messages = []
|
||||
self.memory = None
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Prevent BaseAgentRunner __init__ from hitting database
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
|
||||
return_value=[],
|
||||
)
|
||||
# Prepare required constructor dependencies for BaseAgentRunner
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock()
|
||||
application_generate_entity.model_conf.stop = []
|
||||
application_generate_entity.model_conf.provider = "openai"
|
||||
application_generate_entity.model_conf.parameters = {}
|
||||
application_generate_entity.trace_manager = None
|
||||
application_generate_entity.invoke_from = "test"
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock()
|
||||
app_config.agent.max_iteration = 1
|
||||
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.invoke_llm.return_value = []
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.model = "test-model"
|
||||
|
||||
queue_manager = MagicMock()
|
||||
message = MagicMock()
|
||||
|
||||
runner = DummyRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=model_config,
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Patch internal methods to isolate behavior
|
||||
runner._repack_app_generate_entity = MagicMock()
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.create_agent_thought = MagicMock(return_value="thought-id")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
class TestFillInputs:
|
||||
@pytest.mark.parametrize(
|
||||
("instruction", "inputs", "expected"),
|
||||
[
|
||||
("Hello {{name}}", {"name": "John"}, "Hello John"),
|
||||
("No placeholders", {"name": "John"}, "No placeholders"),
|
||||
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
|
||||
("{{x}}", {"x": None}, "None"),
|
||||
("", {"x": "y"}, ""),
|
||||
],
|
||||
)
|
||||
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
|
||||
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestConvertDictToAction:
|
||||
def test_convert_valid_dict(self, runner):
|
||||
action_dict = {"action": "test", "action_input": {"a": 1}}
|
||||
action = runner._convert_dict_to_action(action_dict)
|
||||
assert action.action_name == "test"
|
||||
assert action.action_input == {"a": 1}
|
||||
|
||||
def test_convert_missing_keys(self, runner):
|
||||
with pytest.raises(KeyError):
|
||||
runner._convert_dict_to_action({"invalid": 1})
|
||||
|
||||
|
||||
class TestFormatAssistantMessage:
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner):
|
||||
sp1 = AgentScratchpadUnit(
|
||||
agent_response="resp1",
|
||||
thought="thought1",
|
||||
action_str="action1",
|
||||
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
|
||||
observation="obs1",
|
||||
)
|
||||
sp2 = AgentScratchpadUnit(
|
||||
agent_response="final",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
|
||||
observation=None,
|
||||
)
|
||||
result = runner._format_assistant_message([sp1, sp2])
|
||||
assert "Final Answer:" in result
|
||||
|
||||
def test_format_with_final(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="Done",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
# Simulate final state via action name
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Final Answer" in result
|
||||
|
||||
def test_format_with_action_and_observation(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="resp",
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
action=None,
|
||||
observation="obs",
|
||||
)
|
||||
# Non-final state: provide a non-final action
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Thought:" in result
|
||||
assert "Action:" in result
|
||||
assert "Observation:" in result
|
||||
|
||||
|
||||
class TestHandleInvokeAction:
|
||||
def test_handle_invoke_action_tool_not_present(self, runner):
|
||||
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("result", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, [])
|
||||
assert response == "result"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
)
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
|
||||
usage_1 = LLMUsage.empty_usage()
|
||||
usage_1.prompt_tokens = 1
|
||||
usage_1.completion_tokens = 1
|
||||
usage_1.total_tokens = 2
|
||||
usage_1.prompt_price = 1
|
||||
usage_1.completion_price = 1
|
||||
usage_1.total_price = 2
|
||||
|
||||
usage_2 = LLMUsage.empty_usage()
|
||||
usage_2.prompt_tokens = 1
|
||||
usage_2.completion_tokens = 1
|
||||
usage_2.total_tokens = 2
|
||||
usage_2.prompt_price = 1
|
||||
usage_2.completion_price = 1
|
||||
usage_2.total_price = 2
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
handle_output = mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[
|
||||
[action],
|
||||
[],
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_side_effect(chunks, usage_dict):
|
||||
call_index = handle_output.call_count
|
||||
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
|
||||
return [action] if call_index == 1 else []
|
||||
|
||||
handle_output.side_effect = _handle_side_effect
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
final_usage = results[-1].delta.usage
|
||||
assert final_usage is not None
|
||||
assert final_usage.prompt_tokens == 2
|
||||
assert final_usage.completion_tokens == 2
|
||||
assert final_usage.total_tokens == 4
|
||||
assert final_usage.prompt_price == 2
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
# First iteration → action
|
||||
# Second iteration → no action (empty list)
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[[action], []],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.app_config.agent.max_iteration = 5
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
runner.update_prompt_message_tool.assert_called_once()
|
||||
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="thinking")
|
||||
assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]
|
||||
|
||||
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
|
||||
|
||||
runner.history_prompt_messages = [assistant, tool_msg]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_historic_final_flush_branch(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="final")
|
||||
runner.history_prompt_messages = [assistant]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
|
||||
runner._init_react_state("new-query")
|
||||
|
||||
assert runner._query == "new-query"
|
||||
assert runner._agent_scratchpad == []
|
||||
assert runner._historic_prompt_messages == ["historic"]
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
|
||||
)
|
||||
|
||||
message_file_ids = []
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
|
||||
|
||||
assert response == "ok"
|
||||
assert message_file_ids == ["file1"]
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
|
||||
class TestFillInputsEdgeCases:
|
||||
def test_fill_inputs_with_empty_inputs(self, runner):
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner):
|
||||
class BadValue:
|
||||
def __str__(self):
|
||||
raise Exception("fail")
|
||||
|
||||
# Should silently continue on exception
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker):
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
|
||||
runner.history_prompt_messages = [user_message]
|
||||
|
||||
mock_transform = mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
)
|
||||
mock_transform.return_value.get_prompt.return_value = ["final"]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == ["final"]
|
||||
|
||||
def test_tool_message_without_scratchpad_raises(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
|
||||
|
||||
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
return_value=mock_transform,
|
||||
)
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=["thinking"],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
# Remove invalid branch: Pydantic enforces str|dict for action_input
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "12345"
|
||||
@@ -1,215 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyAgentConfig,
|
||||
DummyAppConfig,
|
||||
DummyTool,
|
||||
)
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyPromptEntity as DummyPrompt,
|
||||
)
|
||||
|
||||
|
||||
class DummyFileUploadConfig:
|
||||
def __init__(self, image_config=None):
|
||||
self.image_config = image_config
|
||||
|
||||
|
||||
class DummyImageConfig:
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, file_upload_config=None):
|
||||
self.file_upload_config = file_upload_config
|
||||
|
||||
|
||||
class DummyUnit:
|
||||
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
|
||||
runner._instruction = "test_instruction"
|
||||
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
|
||||
runner._query = "user query"
|
||||
runner._agent_scratchpad = []
|
||||
runner.files = []
|
||||
runner.application_generate_entity = DummyGenerateEntity()
|
||||
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
|
||||
return runner
|
||||
|
||||
|
||||
class TestOrganizeSystemPrompt:
|
||||
def test_organize_system_prompt_success(self, runner, mocker):
|
||||
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_chat_agent_runner.jsonable_encoder",
|
||||
return_value=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
|
||||
result = runner._organize_system_prompt()
|
||||
|
||||
assert "test_instruction" in result.content
|
||||
assert "tool1" in result.content
|
||||
assert "tool2" in result.content
|
||||
assert "tool1, tool2" in result.content
|
||||
|
||||
def test_organize_system_prompt_missing_agent(self, runner):
|
||||
runner.app_config = DummyAppConfig(agent=None)
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
def test_organize_system_prompt_missing_prompt(self, runner):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
|
||||
def test_organize_user_query_no_files(self, runner, files):
|
||||
runner.files = files
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "query"
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
|
||||
image_config = DummyImageConfig(detail="high")
|
||||
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
|
||||
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_no_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assert "system" in result
|
||||
assert "query" in result
|
||||
runner._organize_historic_prompt_messages.assert_called_once()
|
||||
|
||||
def test_with_final_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(final=True, agent_response="done")
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Final Answer: done" in combined
|
||||
|
||||
def test_with_thought_action_observation(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(
|
||||
final=False,
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
observation="observe",
|
||||
)
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: thinking" in combined
|
||||
assert "Action: action" in combined
|
||||
assert "Observation: observe" in combined
|
||||
|
||||
def test_multiple_units_mixed(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
units = [
|
||||
DummyUnit(final=False, thought="t1"),
|
||||
DummyUnit(final=True, agent_response="done"),
|
||||
]
|
||||
runner._agent_scratchpad = units
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: t1" in combined
|
||||
assert "Final Answer: done" in combined
|
||||
@@ -1,234 +0,0 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, dummy_tool_factory):
|
||||
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
|
||||
|
||||
runner._instruction = "Test instruction"
|
||||
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
|
||||
runner._query = "What is Python?"
|
||||
runner._agent_scratchpad = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_completion_agent_runner.jsonable_encoder",
|
||||
side_effect=lambda tools: [{"name": t.name} for t in tools],
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_instruction_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeInstructionPrompt:
|
||||
def test_success_all_placeholders(
|
||||
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = (
|
||||
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
|
||||
)
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
result = runner._organize_instruction_prompt()
|
||||
|
||||
assert "Test instruction" in result
|
||||
assert "toolA" in result
|
||||
assert "toolB" in result
|
||||
tools_payload = json.loads(result.split(" | ")[1])
|
||||
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
|
||||
|
||||
def test_agent_none_raises(self, runner, dummy_app_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=None)
|
||||
with pytest.raises(ValueError, match="Agent configuration is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
|
||||
with pytest.raises(ValueError, match="prompt entity is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_historic_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[user_msg, assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_prompt_messages Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_full_flow_with_scratchpad(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
|
||||
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
|
||||
content = result[0].content
|
||||
|
||||
assert "History" in content
|
||||
assert "Thought: Thinking" in content
|
||||
assert "Action: Act" in content
|
||||
assert "Observation: Obs" in content
|
||||
assert "Final Answer: Done" in content
|
||||
assert "Question: What is Python?" in content
|
||||
|
||||
def test_no_scratchpad(
|
||||
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = None
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert "Question: What is Python?" in result[0].content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("thought", "action", "observation"),
|
||||
[
|
||||
("T", None, None),
|
||||
("T", "A", None),
|
||||
("T", None, "O"),
|
||||
],
|
||||
)
|
||||
def test_partial_scratchpad_units(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
thought,
|
||||
action,
|
||||
observation,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(
|
||||
final=False,
|
||||
thought=thought,
|
||||
action_str=action,
|
||||
observation=observation,
|
||||
)
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
content = result[0].content
|
||||
|
||||
assert "Thought:" in content
|
||||
if action:
|
||||
assert "Action:" in content
|
||||
if observation:
|
||||
assert "Observation:" in content
|
||||
@@ -1,452 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
|
||||
# ==============================
|
||||
# Dummy Helper Classes
|
||||
# ==============================
|
||||
|
||||
|
||||
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = pt
|
||||
usage.completion_tokens = ct
|
||||
usage.total_tokens = tt
|
||||
usage.prompt_price = 0
|
||||
usage.completion_price = 0
|
||||
usage.total_price = 0
|
||||
return usage
|
||||
|
||||
|
||||
class DummyMessage:
|
||||
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
|
||||
self.content: str | None = content
|
||||
self.tool_calls: list[Any] = tool_calls or []
|
||||
|
||||
|
||||
class DummyDelta:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
|
||||
|
||||
|
||||
class DummyResult:
|
||||
def __init__(
|
||||
self,
|
||||
message: DummyMessage | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
prompt_messages: list[DummyMessage] | None = None,
|
||||
):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
self.prompt_messages: list[DummyMessage] = prompt_messages or []
|
||||
self.system_fingerprint: str = ""
|
||||
|
||||
|
||||
# ==============================
|
||||
# Fixtures
|
||||
# ==============================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
# Patch streaming chunk models to avoid validation on dummy message objects
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock(max_iteration=2)
|
||||
app_config.prompt_template = MagicMock(simple_prompt_template="system")
|
||||
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
|
||||
application_generate_entity.trace_manager = MagicMock()
|
||||
application_generate_entity.invoke_from = "test"
|
||||
application_generate_entity.app_config = MagicMock(app_id="app")
|
||||
application_generate_entity.file_upload_config = None
|
||||
|
||||
queue_manager = MagicMock()
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
|
||||
message = MagicMock(id="msg1")
|
||||
conversation = MagicMock(id="conv1")
|
||||
|
||||
runner = FunctionCallAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=MagicMock(),
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Manually inject required attributes normally set by BaseAgentRunner
|
||||
runner.tenant_id = "tenant"
|
||||
runner.application_generate_entity = application_generate_entity
|
||||
runner.conversation = conversation
|
||||
runner.app_config = app_config
|
||||
runner.model_config = MagicMock()
|
||||
runner.config = MagicMock()
|
||||
runner.queue_manager = queue_manager
|
||||
runner.message = message
|
||||
runner.user_id = "user"
|
||||
runner.model_instance = model_instance
|
||||
|
||||
runner.stream_tool_call = False
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.agent_callback = MagicMock()
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.create_agent_thought = MagicMock(return_value="thought1")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ==============================
|
||||
# Tool Call Checks
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestToolCallChecks:
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_tool_calls(self, runner, tool_calls, expected):
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_tool_calls(chunk) is expected
|
||||
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
|
||||
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_blocking_tool_calls(result) is expected
|
||||
|
||||
|
||||
# ==============================
|
||||
# Extract Tool Calls
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_extract_tool_calls_with_valid_json(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {"a": 1})]
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = ""
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {})]
|
||||
|
||||
def test_extract_blocking_tool_calls(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "2"
|
||||
tool_call.function.name = "block"
|
||||
tool_call.function.arguments = json.dumps({"x": 2})
|
||||
|
||||
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_blocking_tool_calls(result)
|
||||
|
||||
assert calls == [("2", "block", {"x": 2})]
|
||||
|
||||
|
||||
# ==============================
|
||||
# System Message Initialization
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
def test_init_system_message_empty_prompt_messages(self, runner):
|
||||
result = runner._init_system_message("system", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_init_system_message_insert_at_start(self, runner):
|
||||
msgs = [MagicMock()]
|
||||
result = runner._init_system_message("system", msgs)
|
||||
assert result[0].content == "system"
|
||||
|
||||
def test_init_system_message_no_template(self, runner):
|
||||
result = runner._init_system_message("", [])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ==============================
|
||||
# Organize User Query
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
def test_without_files(self, runner):
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_none_query(self, runner):
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
return_value=file_content,
|
||||
)
|
||||
|
||||
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
|
||||
runner.files = ["file1"]
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Clear User Prompt Images
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
def test_clear_text_and_image_content(self, runner):
|
||||
text = MagicMock()
|
||||
text.type = "text"
|
||||
text.data = "hello"
|
||||
|
||||
image = MagicMock()
|
||||
image.type = "image"
|
||||
image.data = "img"
|
||||
|
||||
user_msg = MagicMock()
|
||||
user_msg.__class__.__name__ = "UserPromptMessage"
|
||||
user_msg.content = [text, image]
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_clear_includes_file_placeholder(self, runner):
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
image = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
|
||||
|
||||
user_msg = UserPromptMessage(content=[text, image, document])
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
|
||||
assert result[0].content == "hello\n[image]\n[file]"
|
||||
|
||||
|
||||
# ==============================
|
||||
# Run Method Tests
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestRunMethod:
|
||||
def test_run_non_streaming_no_tool_calls(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
dummy_message = DummyMessage(content="hello")
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
queue_calls = runner.queue_manager.publish.call_args_list
|
||||
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
|
||||
|
||||
def test_run_streaming_branch(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_streaming_tool_calls_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_non_streaming_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
dummy_message = DummyMessage(content=content)
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def flaky_dumps(obj, *args, **kwargs):
|
||||
if kwargs.get("ensure_ascii") is False:
|
||||
return real_dumps(obj, *args, **kwargs)
|
||||
raise TypeError("boom")
|
||||
|
||||
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_with_missing_tool_instance(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "missing"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
tool_instance = MagicMock()
|
||||
prompt_tool = MagicMock()
|
||||
prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
|
||||
|
||||
tool_invoke_meta = MagicMock()
|
||||
tool_invoke_meta.to_dict.return_value = {"ok": True}
|
||||
mocker.patch(
|
||||
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], tool_invoke_meta),
|
||||
)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
assert any(
|
||||
isinstance(call.args[0], QueueMessageFileEvent)
|
||||
and call.args[0].message_file_id == "file1"
|
||||
and call.args[1] == PublishFrom.APPLICATION_MANAGER
|
||||
for call in runner.queue_manager.publish.call_args_list
|
||||
)
|
||||
|
||||
def test_run_max_iteration_error(self, runner):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = "{}"
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query"))
|
||||
@@ -1,324 +0,0 @@
|
||||
"""Unit tests for core.agent.plugin_entities.
|
||||
|
||||
Covers entities such as AgentFeature, AgentProviderEntityWithPlugin,
|
||||
AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter,
|
||||
AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on
|
||||
Pydantic ValidationError behavior and pytest fixtures for validation and
|
||||
mocking; ensure entity invariants and validation rules remain stable.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agent.plugin_entities import (
|
||||
AgentFeature,
|
||||
AgentProviderEntityWithPlugin,
|
||||
AgentStrategyEntity,
|
||||
AgentStrategyIdentity,
|
||||
AgentStrategyParameter,
|
||||
AgentStrategyProviderEntity,
|
||||
AgentStrategyProviderIdentity,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
|
||||
|
||||
# =========================================================
|
||||
# Fixtures
|
||||
# =========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_identity(mocker):
|
||||
return mocker.MagicMock(spec=AgentStrategyIdentity)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_identity(mocker):
|
||||
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameterType Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameterType:
|
||||
@pytest.mark.parametrize(
|
||||
"enum_member",
|
||||
list(AgentStrategyParameter.AgentStrategyParameterType),
|
||||
)
|
||||
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
return_value="normalized",
|
||||
)
|
||||
|
||||
result = enum_member.as_normal_type()
|
||||
|
||||
mock_func.assert_called_once_with(enum_member)
|
||||
assert result == "normalized"
|
||||
|
||||
def test_as_normal_type_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
enum_member.as_normal_type()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("enum_member", "value"),
|
||||
[
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, None),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
|
||||
],
|
||||
)
|
||||
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
return_value="casted",
|
||||
)
|
||||
|
||||
result = enum_member.cast_value(value)
|
||||
|
||||
mock_func.assert_called_once_with(enum_member, value)
|
||||
assert result == "casted"
|
||||
|
||||
def test_cast_value_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
side_effect=ValueError("invalid"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
enum_member.cast_value("bad")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameter Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameter:
|
||||
def test_valid_creation_minimal(self) -> None:
|
||||
# bypass base PluginParameter required fields using model_construct
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=None,
|
||||
)
|
||||
assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
assert param.help is None
|
||||
|
||||
def test_valid_creation_with_help(self) -> None:
|
||||
help_obj = I18nObject(en_US="test")
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=help_obj,
|
||||
)
|
||||
assert param.help == help_obj
|
||||
|
||||
@pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}])
|
||||
def test_invalid_type_raises_validation_error(self, invalid_type) -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y"))
|
||||
|
||||
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
|
||||
|
||||
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
return_value="frontend",
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
result = param.init_frontend_parameter("value")
|
||||
|
||||
mock_func.assert_called_once_with(param, param.type, "value")
|
||||
assert result == "frontend"
|
||||
|
||||
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
side_effect=RuntimeError("error"),
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
param.init_frontend_parameter("value")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyProviderEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyProviderEntity:
|
||||
def test_creation_with_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="plugin-123",
|
||||
)
|
||||
assert entity.plugin_id == "plugin-123"
|
||||
|
||||
def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="",
|
||||
)
|
||||
assert entity.plugin_id == ""
|
||||
|
||||
def test_creation_without_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(identity=mock_provider_identity)
|
||||
assert entity.plugin_id is None
|
||||
|
||||
def test_invalid_identity_raises(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyProviderEntity(identity="invalid")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyEntity:
|
||||
def test_parameters_default_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_none_converted_to_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=None,
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_preserved(self, mock_identity) -> None:
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[param],
|
||||
)
|
||||
assert entity.parameters == [param]
|
||||
|
||||
def test_invalid_parameters_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters="invalid",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"features",
|
||||
[
|
||||
None,
|
||||
[],
|
||||
[AgentFeature.HISTORY_MESSAGES],
|
||||
],
|
||||
)
|
||||
def test_features_valid(self, mock_identity, features) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features=features,
|
||||
)
|
||||
assert entity.features == features
|
||||
|
||||
def test_invalid_features_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features="invalid",
|
||||
)
|
||||
|
||||
def test_output_schema_and_meta_version(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
output_schema={"type": "object"},
|
||||
meta_version="v1",
|
||||
)
|
||||
assert entity.output_schema == {"type": "object"}
|
||||
assert entity.meta_version == "v1"
|
||||
|
||||
def test_missing_required_fields_raise(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(identity=mock_identity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentProviderEntityWithPlugin Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentProviderEntityWithPlugin:
|
||||
def test_default_strategies_empty(self, mock_provider_identity) -> None:
|
||||
entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity)
|
||||
assert entity.strategies == []
|
||||
|
||||
def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None:
|
||||
strategy = AgentStrategyEntity.model_construct(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
entity = AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies=[strategy],
|
||||
)
|
||||
assert entity.strategies == [strategy]
|
||||
|
||||
def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies="invalid",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# Inheritance Smoke Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestInheritanceBehavior:
|
||||
def test_agent_strategy_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyIdentity, ToolIdentity)
|
||||
|
||||
def test_agent_strategy_provider_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)
|
||||
@@ -1,75 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestAdvancedChatAppConfigManager:
|
||||
def test_get_app_config(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value)
|
||||
workflow = SimpleNamespace(id="wf-1", features_dict={})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow)
|
||||
|
||||
assert app_config.workflow_id == "wf-1"
|
||||
assert app_config.app_mode == AppMode.ADVANCED_CHAT
|
||||
|
||||
def test_config_validate_filters_keys(self):
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
config = kwargs.get("config") if kwargs else args[-1]
|
||||
config = {**config, key: value}
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("opening_statement", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("suggested_questions_after_answer", 3),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("speech_to_text", 4),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 5),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("retriever_resource", 6),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 7),
|
||||
),
|
||||
):
|
||||
filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={})
|
||||
|
||||
assert filtered["file_upload"] == 1
|
||||
assert filtered["opening_statement"] == 2
|
||||
assert filtered["suggested_questions_after_answer"] == 3
|
||||
assert filtered["speech_to_text"] == 4
|
||||
assert filtered["text_to_speech"] == 5
|
||||
assert filtered["retriever_resource"] == 6
|
||||
assert filtered["sensitive_word_avoidance"] == 7
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,96 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateResponseConverter:
|
||||
def test_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_stream_simple_response_includes_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
elapsed_time=0.1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_start,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_finish,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert converted[0] == "ping"
|
||||
assert converted[1]["event"] == "node_started"
|
||||
assert converted[2]["event"] == "node_finished"
|
||||
assert converted[3]["event"] == "error"
|
||||
@@ -1,600 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.enums import MessageStatus
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _make_pipeline():
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=datetime.utcnow(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
|
||||
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=False,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_ensure_workflow_initialized_raises(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
with pytest.raises(ValueError, match="workflow run not initialized"):
|
||||
pipeline._ensure_workflow_initialized()
|
||||
|
||||
def test_to_blocking_response_returns_message_end(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.answer = "done"
|
||||
|
||||
def _gen():
|
||||
yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"})
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.answer == "done"
|
||||
assert response.data.metadata == {"k": "v"}
|
||||
|
||||
def test_handle_text_chunk_event_updates_state(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager = SimpleNamespace(
|
||||
message_to_stream_response=lambda **kwargs: MessageEndStreamResponse(
|
||||
task_id="task", id="message-id", metadata={}
|
||||
)
|
||||
)
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=None)
|
||||
|
||||
responses = list(pipeline._handle_text_chunk_event(event))
|
||||
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert responses
|
||||
|
||||
def test_listen_audio_msg_returns_audio_stream(self):
|
||||
pipeline = _make_pipeline()
|
||||
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
|
||||
|
||||
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
|
||||
|
||||
assert isinstance(response, MessageAudioStreamResponse)
|
||||
|
||||
def test_handle_ping_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
|
||||
|
||||
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
|
||||
|
||||
assert isinstance(responses[0], PingStreamResponse)
|
||||
|
||||
def test_handle_error_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
pipeline._database_session = _fake_session
|
||||
|
||||
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
|
||||
|
||||
assert isinstance(responses[0], ValueError)
|
||||
|
||||
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
|
||||
|
||||
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
|
||||
|
||||
assert pipeline._workflow_run_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
def test_message_end_to_stream_response_strips_annotation_reply(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id="ann",
|
||||
account=AnnotationReplyAccount(id="acc", name="acc"),
|
||||
)
|
||||
|
||||
response = pipeline._message_end_to_stream_response()
|
||||
|
||||
assert "annotation_reply" not in response.metadata
|
||||
|
||||
def test_handle_output_moderation_chunk_publishes_stop(self):
|
||||
pipeline = _make_pipeline()
|
||||
events: list[object] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return True
|
||||
|
||||
def get_final_output(self):
|
||||
return "final"
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
pipeline._base_task_pipeline.queue_manager = SimpleNamespace(
|
||||
publish=lambda event, pub_from: events.append(event)
|
||||
)
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("ignored")
|
||||
|
||||
assert result is True
|
||||
assert pipeline._task_state.answer == "final"
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in events)
|
||||
assert any(isinstance(event, QueueStopEvent) for event in events)
|
||||
|
||||
def test_handle_node_succeeded_event_records_files(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [
|
||||
{"type": "file", "transfer_method": "local"}
|
||||
]
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
event = SimpleNamespace(
|
||||
node_type=NodeType.ANSWER,
|
||||
outputs={"k": "v"},
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
assert responses == ["done"]
|
||||
assert pipeline._recorded_files
|
||||
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
|
||||
"iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
|
||||
"iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
|
||||
|
||||
iter_start = QueueIterationStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_next = QueueIterationNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_done = QueueIterationCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_start = QueueLoopStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_next = QueueLoopNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_done = QueueLoopCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"]
|
||||
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"]
|
||||
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"]
|
||||
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"]
|
||||
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
|
||||
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
|
||||
|
||||
def test_workflow_finish_handlers(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"]
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: None
|
||||
pipeline._save_message = lambda **kwargs: None
|
||||
pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id")
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace(scalar=lambda *args, **kwargs: None)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={})))
|
||||
assert len(succeeded_responses) == 2
|
||||
assert isinstance(succeeded_responses[0], MessageEndStreamResponse)
|
||||
assert succeeded_responses[1] == "finish"
|
||||
|
||||
partial_success_responses = list(
|
||||
pipeline._handle_workflow_partial_success_event(
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
)
|
||||
)
|
||||
assert len(partial_success_responses) == 2
|
||||
assert isinstance(partial_success_responses[0], MessageEndStreamResponse)
|
||||
assert partial_success_responses[1] == "finish"
|
||||
assert (
|
||||
list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0]
|
||||
== "finish"
|
||||
)
|
||||
assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [
|
||||
"pause"
|
||||
]
|
||||
|
||||
def test_node_failure_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
failed_event = QueueNodeFailedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
exc_event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"]
|
||||
assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"]
|
||||
|
||||
def test_handle_text_chunk_event_tracks_streaming_metrics(self):
|
||||
pipeline = _make_pipeline()
|
||||
published: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
published.append(message)
|
||||
|
||||
pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=["a"])
|
||||
queue_message = SimpleNamespace(event=event)
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
|
||||
)
|
||||
|
||||
assert responses == ["chunk"]
|
||||
assert pipeline._task_state.is_streaming_response is True
|
||||
assert pipeline._task_state.first_token_time is not None
|
||||
assert pipeline._task_state.last_token_time is not None
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert published == [queue_message]
|
||||
|
||||
def test_handle_output_moderation_chunk_appends_token(self):
|
||||
pipeline = _make_pipeline()
|
||||
seen: list[str] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return False
|
||||
|
||||
def append_new_token(self, text):
|
||||
seen.append(text)
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("token")
|
||||
|
||||
assert result is False
|
||||
assert seen == ["token"]
|
||||
|
||||
def test_handle_retriever_and_annotation_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"retriever": 0, "annotation": 0}
|
||||
|
||||
def _hit_retriever(event):
|
||||
calls["retriever"] += 1
|
||||
|
||||
def _hit_annotation(event):
|
||||
calls["annotation"] += 1
|
||||
|
||||
pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever
|
||||
pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation
|
||||
|
||||
retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[])
|
||||
annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann")
|
||||
|
||||
assert list(pipeline._handle_retriever_resources_event(retriever_event)) == []
|
||||
assert list(pipeline._handle_annotation_reply_event(annotation_event)) == []
|
||||
assert calls == {"retriever": 1, "annotation": 1}
|
||||
|
||||
def test_handle_message_replace_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
|
||||
event = QueueMessageReplaceEvent(
|
||||
text="new",
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_message_replace_event(event)) == ["replace"]
|
||||
|
||||
def test_handle_human_input_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
persisted: list[str] = []
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved")
|
||||
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
|
||||
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
|
||||
|
||||
filled_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
action_id="action",
|
||||
action_text="action",
|
||||
)
|
||||
timeout_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
expiration_time=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
|
||||
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
|
||||
assert persisted == ["saved"]
|
||||
|
||||
def test_save_message_strips_markdown_and_sets_usage(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._recorded_files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote",
|
||||
"remote_url": "http://example.com/file.png",
|
||||
"related_id": "file-id",
|
||||
}
|
||||
]
|
||||
pipeline._task_state.answer = " hello"
|
||||
pipeline._task_state.is_streaming_response = True
|
||||
pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1
|
||||
pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
status=MessageStatus.PAUSED,
|
||||
answer="",
|
||||
updated_at=None,
|
||||
provider_response_latency=None,
|
||||
message_tokens=None,
|
||||
message_unit_price=None,
|
||||
message_price_unit=None,
|
||||
answer_tokens=None,
|
||||
answer_unit_price=None,
|
||||
answer_price_unit=None,
|
||||
total_price=None,
|
||||
currency=None,
|
||||
message_metadata=None,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_account_id=None,
|
||||
from_end_user_id="end-user",
|
||||
)
|
||||
|
||||
class _Session:
|
||||
def scalar(self, *args, **kwargs):
|
||||
return message
|
||||
|
||||
def add_all(self, items):
|
||||
self.items = items
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state)
|
||||
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
assert message.answer == "hello"
|
||||
assert message.message_metadata
|
||||
|
||||
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)))
|
||||
|
||||
assert responses == ["end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent()))
|
||||
|
||||
assert responses == ["replace", "end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_dispatch_event_handles_node_exception(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
|
||||
pipeline._save_output_for_event = lambda *args, **kwargs: None
|
||||
|
||||
event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._dispatch_event(event)) == ["failed"]
|
||||
@@ -1,302 +0,0 @@
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.agent_chat.app_config_manager import (
|
||||
AgentChatAppConfigManager,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerGetAppConfig:
|
||||
def test_get_app_config_override_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"ignored": True}
|
||||
|
||||
override_config = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.variables == "variables"
|
||||
assert result.external_data_variables == "external"
|
||||
|
||||
def test_get_app_config_conversation_specific(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
conversation = mocker.MagicMock()
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == app_model_config.to_dict.return_value
|
||||
assert result.app_model_config_from.value == "conversation-specific-config"
|
||||
|
||||
def test_get_app_config_latest_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from.value == "app-latest-config"
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerConfigValidate:
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {},
|
||||
"user_input_form": {},
|
||||
"file_upload": {},
|
||||
"prompt_template": {},
|
||||
"agent_mode": {},
|
||||
"opening_statement": {},
|
||||
"suggested_questions_after_answer": {},
|
||||
"speech_to_text": {},
|
||||
"text_to_speech": {},
|
||||
"retriever_resource": {},
|
||||
"dataset": {},
|
||||
"moderation": {},
|
||||
"extra": "value",
|
||||
}
|
||||
|
||||
def return_with_key(key):
|
||||
return config, [key]
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("model"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("file_upload"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda app_mode, cfg: return_with_key("prompt_template"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
AgentChatAppConfigManager,
|
||||
"validate_agent_mode_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("opening_statement"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("speech_to_text"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("text_to_speech"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("retriever_resource"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("moderation"),
|
||||
)
|
||||
|
||||
filtered = AgentChatAppConfigManager.config_validate("tenant", config)
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"prompt_template",
|
||||
"agent_mode",
|
||||
"opening_statement",
|
||||
"suggested_questions_after_answer",
|
||||
"speech_to_text",
|
||||
"text_to_speech",
|
||||
"retriever_resource",
|
||||
"dataset",
|
||||
"moderation",
|
||||
}
|
||||
assert "extra" not in filtered
|
||||
|
||||
|
||||
class TestValidateAgentModeAndSetDefaults:
|
||||
def test_defaults_when_missing(self):
|
||||
config = {}
|
||||
updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert "agent_mode" in updated
|
||||
assert updated["agent_mode"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"] == []
|
||||
assert keys == ["agent_mode"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_mode",
|
||||
["invalid", 123],
|
||||
)
|
||||
def test_agent_mode_type_validation(self, agent_mode):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode})
|
||||
|
||||
def test_agent_mode_empty_list_defaults(self):
|
||||
config = {"agent_mode": []}
|
||||
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert updated["agent_mode"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"] == []
|
||||
|
||||
def test_enabled_must_be_bool(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}})
|
||||
|
||||
def test_strategy_must_be_valid(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}}
|
||||
)
|
||||
|
||||
def test_tools_must_be_list(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}}
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_requires_id(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}}
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_id_must_be_uuid(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_id_not_exists(self, mocker):
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
|
||||
return_value=False,
|
||||
)
|
||||
dataset_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}},
|
||||
)
|
||||
|
||||
def test_old_tool_enabled_must_be_bool(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"])
|
||||
def test_new_style_tool_requires_fields(self, missing_key):
|
||||
tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"}
|
||||
tool.pop(missing_key, None)
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
|
||||
)
|
||||
|
||||
def test_valid_old_and_new_style_tools(self, mocker):
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
|
||||
return_value=True,
|
||||
)
|
||||
dataset_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": PlanningStrategy.ROUTER.value,
|
||||
"tools": [
|
||||
{"dataset": {"id": dataset_id}},
|
||||
{
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "p1",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"][1]["enabled"] is False
|
||||
@@ -1,296 +0,0 @@
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id):
|
||||
self.id = user_id
|
||||
self.session_id = f"session-{user_id}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = AgentChatAppGenerator()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.current_app",
|
||||
new=mocker.MagicMock(_get_current_object=mocker.MagicMock()),
|
||||
)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx")
|
||||
return gen
|
||||
|
||||
|
||||
class TestAgentChatAppGeneratorGenerate:
|
||||
def test_generate_rejects_blocking_mode(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)
|
||||
|
||||
def test_generate_requires_query(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())
|
||||
|
||||
def test_generate_rejects_non_string_query(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"query": 123, "inputs": {}},
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
def test_generate_override_requires_debugger(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_success_with_debugger_override(self, generator, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
user = DummyAccount("user")
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
|
||||
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
|
||||
generator._init_generate_records = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
|
||||
)
|
||||
generator._handle_response = mocker.MagicMock(return_value="response")
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate",
|
||||
return_value={"validated": True},
|
||||
)
|
||||
app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[])
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
|
||||
return_value=["file-obj"],
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ConversationService.get_conversation",
|
||||
return_value=mocker.MagicMock(id="conv"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
queue_manager = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
|
||||
return_value=queue_manager,
|
||||
)
|
||||
|
||||
thread_obj = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.threading.Thread",
|
||||
return_value=thread_obj,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
|
||||
return_value=app_entity,
|
||||
)
|
||||
|
||||
args = {
|
||||
"query": "hello",
|
||||
"inputs": {"name": "world"},
|
||||
"conversation_id": "conv",
|
||||
"model_config": {"model": {"provider": "p"}},
|
||||
"files": [{"id": "f1"}],
|
||||
}
|
||||
|
||||
result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
thread_obj.start.assert_called_once()
|
||||
|
||||
def test_generate_without_file_config(self, generator, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
user = DummyAccount("user")
|
||||
|
||||
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
|
||||
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
|
||||
generator._init_generate_records = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
|
||||
)
|
||||
generator._handle_response = mocker.MagicMock(return_value="response")
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
|
||||
return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
|
||||
return_value=["file-obj"],
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
thread_obj = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.threading.Thread",
|
||||
return_value=thread_obj,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
|
||||
return_value=app_entity,
|
||||
)
|
||||
|
||||
args = {"query": "hello", "inputs": {"name": "world"}}
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
class TestAgentChatAppGeneratorWorker:
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_context(self, mocker):
|
||||
@contextlib.contextmanager
|
||||
def ctx_manager(*args, **kwargs):
|
||||
yield
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)
|
||||
|
||||
def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = GenerateTaskStoppedError()
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
queue_manager.publish_error.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error",
|
||||
[
|
||||
InvokeAuthorizationError("bad"),
|
||||
ValidationError.from_exception_data("TestModel", []),
|
||||
ValueError("bad"),
|
||||
Exception("bad"),
|
||||
],
|
||||
)
|
||||
def test_generate_worker_publishes_errors(self, generator, mocker, error):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = error
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert queue_manager.publish_error.called
|
||||
|
||||
def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = ValueError("bad")
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True))
|
||||
logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
logger.exception.assert_called_once()
|
||||
@@ -1,413 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return AgentChatAppRunner()
|
||||
|
||||
|
||||
class TestAgentChatAppRunnerRun:
|
||||
def test_run_app_not_found(self, runner, mocker):
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
|
||||
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_moderation_error_direct_output(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
user_id="user",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
annotation = mocker.MagicMock(id="anno", content="answer")
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation)
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
|
||||
queue_manager = mocker.MagicMock()
|
||||
runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
queue_manager.publish.assert_called_once()
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_hosting_moderation_short_circuits(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=True)
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_model_schema_missing(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = None
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "expected_runner"),
|
||||
[
|
||||
(LLMMode.CHAT, "CotChatAgentRunner"),
|
||||
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
|
||||
],
|
||||
)
|
||||
def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: mode}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
runner_instance.run.assert_called_once()
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_invalid_llm_mode_raises(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = [ModelFeature.TOOL_CALL]
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner_instance.run.assert_called_once()
|
||||
|
||||
def test_run_conversation_not_found(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, None],
|
||||
)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_message_not_found(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
|
||||
)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_invalid_agent_strategy_raises(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
@@ -1,162 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentChatAppGenerateResponseConverterBlocking:
|
||||
def test_convert_blocking_full_response(self):
|
||||
blocking = ChatbotAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={"a": 1},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert result["event"] == "message"
|
||||
assert result["answer"] == "answer"
|
||||
assert result["metadata"] == {"a": 1}
|
||||
|
||||
def test_convert_blocking_simple_response_with_dict_metadata(self):
|
||||
blocking = ChatbotAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
"usage": {"prompt_tokens": 1},
|
||||
},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
|
||||
def test_convert_blocking_simple_response_with_non_dict_metadata(self):
|
||||
blocking = ChatbotAppBlockingResponse.model_construct(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data.model_construct(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata="bad",
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert result["metadata"] == {}
|
||||
|
||||
|
||||
class TestAgentChatAppGenerateResponseConverterStream:
|
||||
def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
def _gen():
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=2,
|
||||
stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=3,
|
||||
stream_response=MessageEndStreamResponse(
|
||||
task_id="t",
|
||||
id="m1",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
"summary": "summary",
|
||||
"extra": "ignored",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
"usage": {"prompt_tokens": 1},
|
||||
},
|
||||
),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=4,
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")),
|
||||
)
|
||||
|
||||
return _gen()
|
||||
|
||||
def test_convert_stream_full_response(self):
|
||||
items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream()))
|
||||
assert items[0] == "ping"
|
||||
assert items[1]["event"] == "message"
|
||||
assert "answer" in items[1]
|
||||
assert items[2]["event"] == "message_end"
|
||||
assert items[3]["event"] == "error"
|
||||
|
||||
def test_convert_stream_simple_response(self):
|
||||
items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream()))
|
||||
assert items[0] == "ping"
|
||||
# Assert the message event structure and content at items[1]
|
||||
assert items[1]["event"] == "message"
|
||||
assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"]
|
||||
assert items[2]["event"] == "message_end"
|
||||
assert "metadata" in items[2]
|
||||
metadata = items[2]["metadata"]
|
||||
assert "annotation_reply" not in metadata
|
||||
assert "usage" not in metadata
|
||||
assert metadata["retriever_resources"] == [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
"summary": "summary",
|
||||
}
|
||||
]
|
||||
assert items[3]["event"] == "error"
|
||||
@@ -1,113 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestChatAppConfigManager:
|
||||
def test_get_app_config_uses_override_dict(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
|
||||
app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"})
|
||||
override = {"model": "override"}
|
||||
|
||||
model_entity = ModelConfigEntity(provider="p", model="m")
|
||||
prompt_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="hi",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None),
|
||||
patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])),
|
||||
):
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override,
|
||||
)
|
||||
|
||||
assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert app_config.app_model_config_dict == override
|
||||
assert app_config.app_mode == AppMode.CHAT
|
||||
|
||||
def test_config_validate_filters_related_keys(self):
|
||||
config = {"extra": 1}
|
||||
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
config = args[-1]
|
||||
config = {**config, key: value}
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("model", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("inputs", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 3),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("prompt", 4),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("dataset", 5),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("opening_statement", 6),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("suggested_questions_after_answer", 7),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("speech_to_text", 8),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 9),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("retriever_resource", 10),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 11),
|
||||
),
|
||||
):
|
||||
filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config)
|
||||
|
||||
assert filtered["model"] == 1
|
||||
assert filtered["inputs"] == 2
|
||||
assert filtered["file_upload"] == 3
|
||||
assert filtered["prompt"] == 4
|
||||
assert filtered["dataset"] == 5
|
||||
assert filtered["opening_statement"] == 6
|
||||
assert filtered["suggested_questions_after_answer"] == 7
|
||||
assert filtered["speech_to_text"] == 8
|
||||
assert filtered["text_to_speech"] == 9
|
||||
assert filtered["retriever_resource"] == 10
|
||||
assert filtered["sensitive_word_avoidance"] == 11
|
||||
@@ -1,280 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class DummyQueueManager:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.published = []
|
||||
|
||||
def publish_error(self, error, pub_from):
|
||||
self.published.append((error, pub_from))
|
||||
|
||||
def publish(self, event, pub_from):
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
class TestChatAppGenerator:
|
||||
def test_generate_requires_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_rejects_non_string_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
args={"query": 1, "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_debugger_overrides_model_config(self):
|
||||
generator = ChatAppGenerator()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
user = SimpleNamespace(id="user-1", session_id="session-1")
|
||||
args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}
|
||||
|
||||
with (
|
||||
patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}),
|
||||
patch(
|
||||
"core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config",
|
||||
return_value=SimpleNamespace(
|
||||
variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT
|
||||
),
|
||||
),
|
||||
patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None),
|
||||
patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity),
|
||||
patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager),
|
||||
patch(
|
||||
"core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True}
|
||||
),
|
||||
patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})),
|
||||
patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}),
|
||||
patch.object(
|
||||
ChatAppGenerator,
|
||||
"_init_generate_records",
|
||||
return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")),
|
||||
),
|
||||
patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}),
|
||||
patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f),
|
||||
patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread,
|
||||
):
|
||||
mock_thread.return_value.start.return_value = None
|
||||
result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_generate_rejects_model_config_override_for_non_debugger(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
with (
|
||||
patch.object(
|
||||
ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})
|
||||
),
|
||||
):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value),
|
||||
user=SimpleNamespace(id="u1", session_id="s1"),
|
||||
args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_worker_handles_exceptions(self):
|
||||
generator = ChatAppGenerator()
|
||||
queue_manager = DummyQueueManager()
|
||||
entity = DummyGenerateEntity(task_id="t1", user_id="u1")
|
||||
|
||||
with (
|
||||
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
|
||||
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()),
|
||||
patch("core.app.apps.chat.app_generator.db.session.close"),
|
||||
):
|
||||
generator._generate_worker(
|
||||
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
|
||||
application_generate_entity=entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
assert queue_manager.published
|
||||
|
||||
with (
|
||||
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
|
||||
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()),
|
||||
patch("core.app.apps.chat.app_generator.db.session.close"),
|
||||
):
|
||||
generator._generate_worker(
|
||||
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
|
||||
application_generate_entity=entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
class TestChatAppRunner:
|
||||
def test_run_raises_when_app_missing(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[]
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
def test_run_moderation_error_direct_output(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
mock_direct.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
annotation = SimpleNamespace(id="ann-1", content="answer")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
):
|
||||
queue_manager = DummyQueueManager()
|
||||
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published)
|
||||
mock_direct.assert_called_once()
|
||||
|
||||
def test_run_returns_when_hosting_moderation_blocks(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
@@ -1,65 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestChatAppGenerateResponseConverter:
|
||||
def test_convert_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_convert_stream_responses(self):
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
assert full[0] == "ping"
|
||||
assert full[1]["event"] == "message"
|
||||
assert full[2]["event"] == "error"
|
||||
|
||||
simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert simple[0] == "ping"
|
||||
assert simple[-1]["event"] == "message_end"
|
||||
@@ -1,162 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.completion.app_runner as module
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CompletionAppRunner()
|
||||
|
||||
|
||||
def _build_app_config(dataset=None, external_tools=None, additional_features=None):
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.tenant_id = "tenant"
|
||||
app_config.prompt_template = MagicMock()
|
||||
app_config.dataset = dataset
|
||||
app_config.external_data_variables = external_tools or []
|
||||
app_config.additional_features = additional_features
|
||||
app_config.app_model_config_dict = {"file_upload": {"enabled": True}}
|
||||
return app_config
|
||||
|
||||
|
||||
def _build_generate_entity(app_config, file_upload_config=None):
|
||||
model_conf = MagicMock(
|
||||
provider_model_bundle="bundle",
|
||||
model="model",
|
||||
parameters={"max_tokens": 10},
|
||||
stop=["stop"],
|
||||
)
|
||||
return SimpleNamespace(
|
||||
app_config=app_config,
|
||||
model_conf=model_conf,
|
||||
inputs={"qvar": "query_from_input"},
|
||||
query="original_query",
|
||||
files=[],
|
||||
file_upload_config=file_upload_config,
|
||||
stream=True,
|
||||
user_id="user",
|
||||
invoke_from=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionAppRunner:
|
||||
def test_run_app_not_found(self, runner, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
|
||||
def test_run_moderation_error_outputs_direct(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked"))
|
||||
runner.direct_output = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_hosting_moderation_stops(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_dataset_and_external_tools_flow(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
session.close = MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
retrieve_config = MagicMock(query_variable="qvar")
|
||||
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
|
||||
additional_features = MagicMock(show_retrieve_source=True)
|
||||
app_config = _build_app_config(
|
||||
dataset=dataset_config,
|
||||
external_tools=["tool"],
|
||||
additional_features=additional_features,
|
||||
)
|
||||
|
||||
file_upload_config = MagicMock()
|
||||
file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])])
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs)
|
||||
runner.check_hosting_moderation = MagicMock(return_value=False)
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
dataset_retrieval = MagicMock()
|
||||
dataset_retrieval.retrieve.return_value = ("ctx", ["file1"])
|
||||
mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval)
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.invoke_llm.return_value = "invoke_result"
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
|
||||
dataset_retrieval.retrieve.assert_called_once()
|
||||
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_uses_low_image_detail_default(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
assert (
|
||||
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
|
||||
== ImagePromptMessageContent.DETAIL.LOW
|
||||
)
|
||||
@@ -1,122 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.completion.app_config_manager as module
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestCompletionAppConfigManager:
|
||||
def test_get_app_config_with_override(self, mocker):
|
||||
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
|
||||
app_model_config = MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
|
||||
override_config = {"model": {"provider": "override"}}
|
||||
|
||||
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
|
||||
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
|
||||
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
|
||||
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
|
||||
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
|
||||
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"]))
|
||||
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.variables == ["v1"]
|
||||
assert result.external_data_variables == ["ext1"]
|
||||
assert result.app_mode == AppMode.COMPLETION
|
||||
|
||||
def test_get_app_config_without_override_uses_model_config(self, mocker):
|
||||
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
|
||||
app_model_config = MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
|
||||
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
|
||||
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
|
||||
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
|
||||
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
|
||||
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
|
||||
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], []))
|
||||
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
|
||||
assert result.app_model_config_dict == {"model": {"provider": "x"}}
|
||||
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {"provider": "x"},
|
||||
"variables": ["v"],
|
||||
"file_upload": {"enabled": True},
|
||||
"prompt": {"template": "t"},
|
||||
"dataset": {"enabled": True},
|
||||
"tts": {"enabled": True},
|
||||
"more_like_this": {"enabled": True},
|
||||
"moderation": {"enabled": True},
|
||||
"extra": "drop",
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
module.ModelConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["model"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.BasicVariablesConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["variables"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.FileUploadConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["file_upload"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PromptTemplateConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["prompt"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DatasetConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["dataset"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.TextToSpeechConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["tts"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.MoreLikeThisConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["more_like_this"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.SensitiveWordAvoidanceConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["moderation"]),
|
||||
)
|
||||
|
||||
filtered = CompletionAppConfigManager.config_validate("tenant", config)
|
||||
|
||||
assert "extra" not in filtered
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"variables",
|
||||
"file_upload",
|
||||
"prompt",
|
||||
"dataset",
|
||||
"tts",
|
||||
"more_like_this",
|
||||
"moderation",
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import core.app.apps.completion.app_generator as module
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = CompletionAppGenerator()
|
||||
|
||||
mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn)
|
||||
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app)))
|
||||
|
||||
thread = MagicMock()
|
||||
mocker.patch.object(module.threading, "Thread", return_value=thread)
|
||||
|
||||
mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock())
|
||||
mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock())
|
||||
mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def _build_app_model():
|
||||
return MagicMock(tenant_id="tenant", id="app1", mode="completion")
|
||||
|
||||
|
||||
def _build_user():
|
||||
return MagicMock(id="user", session_id="session")
|
||||
|
||||
|
||||
def _build_app_model_config():
|
||||
config = MagicMock(id="cfg")
|
||||
config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
return config
|
||||
|
||||
|
||||
class TestCompletionAppGenerator:
|
||||
def test_generate_invalid_query_type(self, generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": 123, "inputs": {}, "files": []},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
def test_generate_override_not_debugger(self, generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {}, "files": [], "model_config": {}},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_success_no_file_config(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings")
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
conversation = MagicMock(id="conv", mode="completion")
|
||||
message = MagicMock(id="msg")
|
||||
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {"a": 1}, "files": []},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
module.file_factory.build_from_mappings.assert_not_called()
|
||||
|
||||
def test_generate_success_with_files(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
|
||||
file_extra_config = MagicMock()
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
conversation = MagicMock(id="conv", mode="completion")
|
||||
message = MagicMock(id="msg")
|
||||
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
module.file_factory.build_from_mappings.assert_called_once()
|
||||
|
||||
def test_generate_override_model_config_debugger(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
|
||||
override_config = {"model": {"provider": "override"}}
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config)
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
get_app_config = mocker.patch.object(
|
||||
module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_init_generate_records",
|
||||
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
|
||||
)
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {}, "files": [], "model_config": override_config},
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert get_app_config.call_args.kwargs["override_config_dict"] == override_config
|
||||
|
||||
def test_generate_more_like_this_message_not_found(self, generator, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=_build_app_model(),
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_disabled(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False})
|
||||
|
||||
message = MagicMock()
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MoreLikeThisDisabledError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_app_model_config_missing(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = None
|
||||
|
||||
message = MagicMock()
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MoreLikeThisDisabledError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_message_config_none(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
|
||||
|
||||
message = MagicMock(app_model_config=None)
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_success(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
|
||||
|
||||
message = MagicMock()
|
||||
message.message_files = [{"id": "f"}]
|
||||
message.inputs = {"a": 1}
|
||||
message.query = "q"
|
||||
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.to_dict.return_value = {
|
||||
"model": {"completion_params": {"temperature": 0.1}},
|
||||
"file_upload": {"enabled": True},
|
||||
}
|
||||
message.app_model_config = app_model_config
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
file_extra_config = MagicMock()
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
get_app_config = mocker.patch.object(
|
||||
module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_init_generate_records",
|
||||
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
|
||||
)
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
override_dict = get_app_config.call_args.kwargs["override_config_dict"]
|
||||
assert override_dict["model"]["completion_params"]["temperature"] == 0.9
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "should_publish"),
|
||||
[
|
||||
(GenerateTaskStoppedError(), False),
|
||||
(InvokeAuthorizationError("bad"), True),
|
||||
(
|
||||
ValidationError.from_exception_data(
|
||||
"Model",
|
||||
[{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}],
|
||||
),
|
||||
True,
|
||||
),
|
||||
(ValueError("bad"), True),
|
||||
(RuntimeError("boom"), True),
|
||||
],
|
||||
)
|
||||
def test_generate_worker_error_handling(self, generator, mocker, error, should_publish):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(generator, "_get_message", return_value=MagicMock())
|
||||
|
||||
runner_instance = MagicMock()
|
||||
runner_instance.run.side_effect = error
|
||||
mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert queue_manager.publish_error.called is should_publish
|
||||
@@ -1,153 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionAppGenerateResponseConverter:
|
||||
def test_convert_blocking_full_response(self):
|
||||
blocking = CompletionAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={"k": "v"},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert result["event"] == "message"
|
||||
assert result["task_id"] == "task"
|
||||
assert result["message_id"] == "msg"
|
||||
assert result["answer"] == "answer"
|
||||
assert result["metadata"] == {"k": "v"}
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_simplified(self):
|
||||
metadata = {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "c",
|
||||
"summary": "sum",
|
||||
"extra": "x",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"a": 1},
|
||||
"usage": {"t": 2},
|
||||
}
|
||||
blocking = CompletionAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata=metadata,
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
|
||||
assert "extra" not in result["metadata"]["retriever_resources"][0]
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_not_dict(self):
|
||||
data = CompletionAppBlockingResponse.Data.model_construct(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata="bad",
|
||||
created_at=123,
|
||||
)
|
||||
blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert result["metadata"] == {}
|
||||
|
||||
def test_convert_stream_full_response(self):
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
message_id="m",
|
||||
created_at=1,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
message_id="m",
|
||||
created_at=2,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"),
|
||||
message_id="m",
|
||||
created_at=3,
|
||||
)
|
||||
|
||||
result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "error"
|
||||
assert result[1]["code"] == "invalid_param"
|
||||
assert result[2]["event"] == "message"
|
||||
|
||||
def test_convert_stream_simple_response(self):
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
message_id="m",
|
||||
created_at=1,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=MessageEndStreamResponse(
|
||||
task_id="t",
|
||||
id="end",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "c",
|
||||
"summary": "sum",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"a": 1},
|
||||
"usage": {"t": 2},
|
||||
},
|
||||
),
|
||||
message_id="m",
|
||||
created_at=2,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
message_id="m",
|
||||
created_at=3,
|
||||
)
|
||||
|
||||
result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "message_end"
|
||||
assert "annotation_reply" not in result[1]["metadata"]
|
||||
assert "usage" not in result[1]["metadata"]
|
||||
assert result[2]["event"] == "error"
|
||||
@@ -1,55 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.pipeline.pipeline_config_manager as module
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_pipeline_config(mocker):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe1")
|
||||
workflow = MagicMock(id="wf1")
|
||||
|
||||
mocker.patch.object(
|
||||
module.WorkflowVariablesConfigManager,
|
||||
"convert_rag_pipeline_variable",
|
||||
return_value=["var1"],
|
||||
)
|
||||
mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start")
|
||||
|
||||
assert result.tenant_id == "tenant"
|
||||
assert result.app_id == "pipe1"
|
||||
assert result.workflow_id == "wf1"
|
||||
assert result.app_mode == AppMode.RAG_PIPELINE
|
||||
assert result.rag_pipeline_variables == ["var1"]
|
||||
|
||||
|
||||
def test_config_validate_filters_related_keys(mocker):
|
||||
config = {
|
||||
"file_upload": {"enabled": True},
|
||||
"tts": {"enabled": True},
|
||||
"moderation": {"enabled": True},
|
||||
"extra": "drop",
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
module.FileUploadConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["file_upload"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.TextToSpeechConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["tts"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.SensitiveWordAvoidanceConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["moderation"]),
|
||||
)
|
||||
|
||||
filtered = PipelineConfigManager.config_validate("tenant", config)
|
||||
|
||||
assert set(filtered.keys()) == {"file_upload", "tts", "moderation"}
|
||||
@@ -1,111 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
def test_convert_blocking_full_and_simple_response():
|
||||
blocking = WorkflowAppBlockingResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="id",
|
||||
workflow_id="wf",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"k": "v"},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=10,
|
||||
total_steps=1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert full == simple
|
||||
assert full["workflow_run_id"] == "run"
|
||||
assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_convert_stream_full_response():
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
workflow_run_id="run",
|
||||
)
|
||||
yield WorkflowAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
workflow_run_id="run",
|
||||
)
|
||||
|
||||
result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "error"
|
||||
assert result[1]["code"] == "invalid_param"
|
||||
|
||||
|
||||
def test_convert_stream_simple_response_node_ignore_details():
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t",
|
||||
workflow_run_id="run",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="nid",
|
||||
node_id="node",
|
||||
node_type="type",
|
||||
title="Title",
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
inputs={"a": 1},
|
||||
inputs_truncated=False,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t",
|
||||
workflow_run_id="run",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="nid",
|
||||
node_id="node",
|
||||
node_type="type",
|
||||
title="Title",
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
inputs={"a": 1},
|
||||
inputs_truncated=False,
|
||||
process_data=None,
|
||||
process_data_truncated=False,
|
||||
outputs={"b": 2},
|
||||
outputs_truncated=False,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
execution_metadata=None,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
files=[],
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run")
|
||||
yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run")
|
||||
|
||||
result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
|
||||
assert result[0]["event"] == "node_started"
|
||||
assert result[0]["data"]["inputs"] is None
|
||||
assert result[1]["event"] == "node_finished"
|
||||
assert result[1]["data"]["inputs"] is None
|
||||
@@ -1,699 +0,0 @@
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_generator as module
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
|
||||
|
||||
class FakeRagPipelineGenerateEntity(SimpleNamespace):
|
||||
class SingleIterationRunEntity(SimpleNamespace):
|
||||
pass
|
||||
|
||||
class SingleLoopRunEntity(SimpleNamespace):
|
||||
pass
|
||||
|
||||
def model_dump(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = module.PipelineGenerator()
|
||||
|
||||
mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity)
|
||||
mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs)
|
||||
mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock()))
|
||||
mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock()))
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def _build_pipeline_dataset():
|
||||
return SimpleNamespace(
|
||||
id="ds",
|
||||
name="dataset",
|
||||
description="desc",
|
||||
chunk_structure="chunk",
|
||||
built_in_field_enabled=True,
|
||||
tenant_id="tenant",
|
||||
)
|
||||
|
||||
|
||||
def _build_pipeline():
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
pipeline.retrieve_dataset.return_value = _build_pipeline_dataset()
|
||||
return pipeline
|
||||
|
||||
|
||||
def _build_workflow():
|
||||
return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant")
|
||||
|
||||
|
||||
def _build_user():
|
||||
return MagicMock(id="user", name="User", session_id="session")
|
||||
|
||||
|
||||
def _build_args():
|
||||
return {
|
||||
"inputs": {"k": "v"},
|
||||
"start_node_id": "start",
|
||||
"datasource_type": DatasourceProviderType.LOCAL_FILE.value,
|
||||
"datasource_info_list": [{"name": "file"}],
|
||||
}
|
||||
|
||||
|
||||
def _patch_session(mocker, session):
|
||||
mocker.patch.object(module, "Session", return_value=session)
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
|
||||
def _dummy_preserve(*args, **kwargs):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
class DummySession:
|
||||
def __init__(self):
|
||||
self.scalar = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def test_generate_dataset_missing(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=_build_workflow(),
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_debugger_calls_generate(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=[{"name": "file"}],
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
datasource_info_list = [{"name": "file1"}, {"name": "file2"}]
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=datasource_info_list,
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1)
|
||||
|
||||
document1 = SimpleNamespace(
|
||||
id="doc1",
|
||||
position=1,
|
||||
data_source_type=DatasourceProviderType.LOCAL_FILE,
|
||||
data_source_info="{}",
|
||||
name="file1",
|
||||
indexing_status="",
|
||||
error=None,
|
||||
enabled=True,
|
||||
)
|
||||
document2 = SimpleNamespace(
|
||||
id="doc2",
|
||||
position=2,
|
||||
data_source_type=DatasourceProviderType.LOCAL_FILE,
|
||||
data_source_info="{}",
|
||||
name="file2",
|
||||
indexing_status="",
|
||||
error=None,
|
||||
enabled=True,
|
||||
)
|
||||
mocker.patch.object(generator, "_build_document", side_effect=[document1, document2])
|
||||
|
||||
mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock())
|
||||
|
||||
db_session = MagicMock()
|
||||
mocker.patch.object(module.db, "session", db_session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
task_proxy = MagicMock()
|
||||
mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy)
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result["batch"]
|
||||
assert len(result["documents"]) == 2
|
||||
task_proxy.delay.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_is_retry_calls_generate(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=[{"name": "file"}],
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
streaming=True,
|
||||
is_retry=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
def test_generate_worker_handles_errors(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
application_generate_entity = FakeRagPipelineGenerateEntity(
|
||||
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
session = DummySession()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
|
||||
_patch_session(mocker, session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
runner_instance.run.side_effect = ValueError("bad")
|
||||
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
context=contextlib.nullcontext(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
queue_manager.publish_error.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
application_generate_entity = FakeRagPipelineGenerateEntity(
|
||||
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
session = DummySession()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
|
||||
_patch_session(mocker, session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
context=contextlib.nullcontext(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session"
|
||||
|
||||
|
||||
def test_generate_raises_when_workflow_not_found(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator._generate(
|
||||
flask_app=flask_app,
|
||||
context=contextlib.nullcontext(),
|
||||
pipeline=_build_pipeline(),
|
||||
workflow_id="wf",
|
||||
user=_build_user(),
|
||||
application_generate_entity=FakeRagPipelineGenerateEntity(
|
||||
task_id="t",
|
||||
app_config=SimpleNamespace(app_id="pipe"),
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_success_returns_converted(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager)
|
||||
|
||||
worker_thread = MagicMock()
|
||||
mocker.patch.object(module.threading, "Thread", return_value=worker_thread)
|
||||
|
||||
mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock())
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator._generate(
|
||||
flask_app=flask_app,
|
||||
context=contextlib.nullcontext(),
|
||||
pipeline=_build_pipeline(),
|
||||
workflow_id="wf",
|
||||
user=_build_user(),
|
||||
application_generate_entity=FakeRagPipelineGenerateEntity(
|
||||
task_id="t",
|
||||
app_config=SimpleNamespace(app_id="pipe"),
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
|
||||
|
||||
def test_single_iteration_generate_validates_inputs(generator, mocker):
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(
|
||||
_build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None}
|
||||
)
|
||||
|
||||
|
||||
def test_single_iteration_generate_dataset_required(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
)
|
||||
|
||||
|
||||
def test_single_iteration_generate_success(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
|
||||
|
||||
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
|
||||
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"ok": True})
|
||||
|
||||
result = generator.single_iteration_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
|
||||
def test_single_loop_generate_success(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
|
||||
|
||||
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
|
||||
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"ok": True})
|
||||
|
||||
result = generator.single_loop_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
|
||||
def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
app_entity = FakeRagPipelineGenerateEntity(task_id="t")
|
||||
|
||||
task_pipeline = MagicMock()
|
||||
task_pipeline.process.side_effect = ValueError("I/O operation on closed file.")
|
||||
mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=app_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=MagicMock(),
|
||||
user=_build_user(),
|
||||
draft_var_saver_factory=MagicMock(),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
|
||||
class DummyDocument(SimpleNamespace):
|
||||
pass
|
||||
|
||||
mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs))
|
||||
|
||||
document = generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=True,
|
||||
datasource_type=DatasourceProviderType.LOCAL_FILE,
|
||||
datasource_info={"name": "file"},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
assert document.name == "file"
|
||||
assert document.doc_metadata
|
||||
|
||||
|
||||
def test_build_document_invalid_datasource_type(generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=False,
|
||||
datasource_type="invalid",
|
||||
datasource_info={},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
|
||||
def test_format_datasource_info_list_non_online_drive(generator):
|
||||
result = generator._format_datasource_info_list(
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
[{"name": "file"}],
|
||||
_build_pipeline(),
|
||||
_build_workflow(),
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
assert result == [{"name": "file"}]
|
||||
|
||||
|
||||
def test_format_datasource_info_list_missing_node_data(generator):
|
||||
workflow = MagicMock(graph_dict={"nodes": []})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator._format_datasource_info_list(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
[],
|
||||
_build_pipeline(),
|
||||
workflow,
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
|
||||
def test_format_datasource_info_list_online_drive_folder(generator, mocker):
|
||||
workflow = MagicMock(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"data": {
|
||||
"plugin_id": "p",
|
||||
"provider_name": "provider",
|
||||
"datasource_name": "drive",
|
||||
"credential_id": "cred",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
runtime = MagicMock()
|
||||
runtime.runtime = SimpleNamespace(credentials=None)
|
||||
runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=runtime,
|
||||
)
|
||||
mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_get_files_in_folder",
|
||||
side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}),
|
||||
)
|
||||
|
||||
result = generator._format_datasource_info_list(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
[{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}],
|
||||
_build_pipeline(),
|
||||
workflow,
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
assert result == [{"id": "f"}]
|
||||
|
||||
|
||||
def test_get_files_in_folder_recurses_and_collects(generator):
|
||||
class File:
|
||||
def __init__(self, id, name, type):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.type = type
|
||||
|
||||
class FilesPage:
|
||||
def __init__(self, files, is_truncated=False, next_page_parameters=None):
|
||||
self.files = files
|
||||
self.is_truncated = is_truncated
|
||||
self.next_page_parameters = next_page_parameters
|
||||
|
||||
class Result:
|
||||
def __init__(self, result):
|
||||
self.result = result
|
||||
|
||||
class Runtime:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def datasource_provider_type(self):
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def online_drive_browse_files(self, user_id, request, provider_type):
|
||||
self.calls.append(request.next_page_parameters)
|
||||
if request.prefix == "fd":
|
||||
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
|
||||
if request.next_page_parameters is None:
|
||||
return iter(
|
||||
[
|
||||
Result(
|
||||
[FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})]
|
||||
)
|
||||
]
|
||||
)
|
||||
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
|
||||
|
||||
runtime = Runtime()
|
||||
all_files = []
|
||||
|
||||
generator._get_files_in_folder(
|
||||
datasource_runtime=runtime,
|
||||
prefix="root",
|
||||
bucket="b",
|
||||
user_id="user",
|
||||
all_files=all_files,
|
||||
datasource_info={},
|
||||
)
|
||||
|
||||
assert {f["id"] for f in all_files} == {"f1", "f2"}
|
||||
@@ -1,57 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_queue_manager as module
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
|
||||
|
||||
def test_publish_sets_stop_listen_and_raises_on_stopped(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=True)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
|
||||
def test_publish_stop_events_trigger_stop_listen(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=False)
|
||||
|
||||
for event in [
|
||||
QueueErrorEvent(error=ValueError("bad")),
|
||||
QueueMessageEndEvent(llm_result=LLMResult.model_construct()),
|
||||
QueueWorkflowSucceededEvent(),
|
||||
QueueWorkflowFailedEvent(error="failed", exceptions_count=1),
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1),
|
||||
]:
|
||||
manager.stop_listen.reset_mock()
|
||||
manager._publish(event, PublishFrom.TASK_PIPELINE)
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
|
||||
def test_publish_non_stop_event_no_stop_listen(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=False)
|
||||
|
||||
non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent)
|
||||
manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE)
|
||||
manager.stop_listen.assert_not_called()
|
||||
@@ -1,297 +0,0 @@
|
||||
"""
|
||||
Unit tests for PipelineRunner behavior.
|
||||
Asserts correct event handling, error propagation, and user invocation logic.
|
||||
Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies.
|
||||
Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities.
|
||||
"""
|
||||
|
||||
"""Unit tests for PipelineRunner behavior.
|
||||
|
||||
This module validates core control-flow outcomes for
|
||||
``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph
|
||||
initialization guards, invoke-source to user-source resolution, and failed-run
|
||||
event handling. Invariants asserted here include strict graph-config
|
||||
validation, correct ``InvokeFrom`` to ``UserFrom`` mapping, and publishing
|
||||
error paths driven by ``GraphRunFailedEvent`` through mocked collaborators.
|
||||
Primary collaborators include ``PipelineRunner``,
|
||||
``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``,
|
||||
``UserFrom``, and patched DB/runtime dependencies used by the runner.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_runner as module
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from dify_graph.graph_events import GraphRunFailedEvent
|
||||
|
||||
|
||||
def _build_app_generate_entity() -> SimpleNamespace:
|
||||
app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant")
|
||||
return SimpleNamespace(
|
||||
app_config=app_config,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
trace_manager=MagicMock(),
|
||||
inputs={"input1": "v1"},
|
||||
files=[],
|
||||
workflow_execution_id="run",
|
||||
document_id="doc",
|
||||
original_document_id=None,
|
||||
batch="batch",
|
||||
dataset_id="ds",
|
||||
datasource_type="local_file",
|
||||
datasource_info={"name": "file"},
|
||||
start_node_id="start",
|
||||
call_depth=0,
|
||||
single_iteration_run=None,
|
||||
single_loop_run=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
queue_manager = MagicMock()
|
||||
variable_loader = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow_execution_repository = MagicMock()
|
||||
workflow_node_execution_repository = MagicMock()
|
||||
|
||||
return PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
|
||||
def test_get_app_id(runner):
|
||||
assert runner._get_app_id() == "pipe"
|
||||
|
||||
|
||||
def test_get_workflow_returns_workflow(mocker, runner):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
workflow = MagicMock(id="wf")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query)))
|
||||
|
||||
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
|
||||
def test_init_rag_pipeline_graph_invalid_config(mocker, runner):
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
workflow.graph_dict = {"nodes": "bad", "edges": []}
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
workflow.graph_dict = {"nodes": [], "edges": "bad"}
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
|
||||
def test_init_rag_pipeline_graph_not_found(mocker, runner):
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []})
|
||||
mocker.patch.object(module.Graph, "init", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
|
||||
def test_update_document_status_on_failure(mocker, runner):
|
||||
document = MagicMock()
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
event = GraphRunFailedEvent(error="boom")
|
||||
|
||||
runner._update_document_status(event, document_id="doc", dataset_id="ds")
|
||||
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error == "boom"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_run_pipeline_not_found(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
app_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
app_generate_entity.single_iteration_run = None
|
||||
app_generate_entity.single_loop_run = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run()
|
||||
|
||||
|
||||
def test_run_workflow_not_initialized(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_pipeline
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
runner.get_workflow = MagicMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run()
|
||||
|
||||
|
||||
def test_run_single_iteration_path(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
app_generate_entity.single_iteration_run = MagicMock()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
|
||||
runner.get_workflow = MagicMock(
|
||||
return_value=MagicMock(
|
||||
id="wf",
|
||||
tenant_id="tenant",
|
||||
app_id="pipe",
|
||||
graph_dict={},
|
||||
type="rag-pipeline",
|
||||
version="v1",
|
||||
)
|
||||
)
|
||||
runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state"))
|
||||
runner._update_document_status = MagicMock()
|
||||
runner._handle_event = MagicMock()
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = [MagicMock()]
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
runner._prepare_single_node_execution.assert_called_once()
|
||||
runner._handle_event.assert_called()
|
||||
|
||||
|
||||
def test_run_normal_path_builds_graph(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
workflow = MagicMock(
|
||||
id="wf",
|
||||
tenant_id="tenant",
|
||||
app_id="pipe",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
environment_variables=[],
|
||||
rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}],
|
||||
type="rag-pipeline",
|
||||
version="v1",
|
||||
)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
|
||||
runner.get_workflow = MagicMock(return_value=workflow)
|
||||
runner._init_rag_pipeline_graph = MagicMock(return_value="graph")
|
||||
runner._update_document_status = MagicMock()
|
||||
runner._handle_event = MagicMock()
|
||||
|
||||
mocker.patch.object(
|
||||
module.RAGPipelineVariable,
|
||||
"model_validate",
|
||||
return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"),
|
||||
)
|
||||
mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = []
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
runner._init_rag_pipeline_graph.assert_called_once()
|
||||
@@ -1,5 +1,3 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@@ -368,132 +366,3 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default():
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBaseAppGeneratorExtras:
|
||||
def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="file",
|
||||
label="file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=False,
|
||||
allowed_file_types=[],
|
||||
allowed_file_extensions=[],
|
||||
allowed_file_upload_methods=[],
|
||||
),
|
||||
VariableEntity(
|
||||
variable="file_list",
|
||||
label="file_list",
|
||||
type=VariableEntityType.FILE_LIST,
|
||||
required=False,
|
||||
allowed_file_types=[],
|
||||
allowed_file_extensions=[],
|
||||
allowed_file_upload_methods=[],
|
||||
),
|
||||
VariableEntity(
|
||||
variable="json",
|
||||
label="json",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_generator.file_factory.build_from_mapping",
|
||||
lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_generator.file_factory.build_from_mappings",
|
||||
lambda mappings, tenant_id, config: ["file-1", "file-2"],
|
||||
)
|
||||
|
||||
user_inputs = {
|
||||
"file": {"id": "file-id"},
|
||||
"file_list": [{"id": "file-1"}, {"id": "file-2"}],
|
||||
"json": {"key": "value"},
|
||||
}
|
||||
|
||||
prepared = base_app_generator._prepare_user_inputs(
|
||||
user_inputs=user_inputs,
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
assert prepared["file"] == "file-object"
|
||||
assert prepared["file_list"] == ["file-1", "file-2"]
|
||||
assert prepared["json"] == {"key": "value"}
|
||||
|
||||
def test_prepare_user_inputs_rejects_invalid_dict_inputs(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text",
|
||||
label="text",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
base_app_generator._prepare_user_inputs(
|
||||
user_inputs={"text": {"unexpected": "dict"}},
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_prepare_user_inputs_rejects_invalid_list_inputs(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text",
|
||||
label="text",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
base_app_generator._prepare_user_inputs(
|
||||
user_inputs={"text": [{"unexpected": "dict"}]},
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_convert_to_event_stream(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True}
|
||||
|
||||
def _gen():
|
||||
yield {"delta": "hi"}
|
||||
yield "ping"
|
||||
|
||||
converted = list(base_app_generator.convert_to_event_stream(_gen()))
|
||||
|
||||
assert converted[0].startswith("data: ")
|
||||
assert "\n\n" in converted[0]
|
||||
assert converted[1] == "event: ping\n\n"
|
||||
|
||||
def test_get_draft_var_saver_factory_debugger(self):
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.enums import NodeType
|
||||
from models import Account
|
||||
|
||||
base_app_generator = BaseAppGenerator()
|
||||
account = Account(name="Tester", email="tester@example.com")
|
||||
account.id = "account-id"
|
||||
account.tenant_id = "tenant-id"
|
||||
|
||||
factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account)
|
||||
saver = factory(
|
||||
session=MagicMock(),
|
||||
app_id="app-id",
|
||||
node_id="node-id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="node-exec-id",
|
||||
)
|
||||
|
||||
assert saver is not None
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueErrorEvent
|
||||
|
||||
|
||||
class DummyQueueManager(AppQueueManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.published = []
|
||||
|
||||
def _publish(self, event, pub_from):
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
class TestBaseAppQueueManager:
|
||||
def test_init_requires_user_id(self):
|
||||
with pytest.raises(ValueError):
|
||||
DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
def test_publish_error_records_event(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
assert isinstance(manager.published[0][0], QueueErrorEvent)
|
||||
|
||||
def test_set_stop_flag_checks_user(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b"end-user-u1"
|
||||
AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1")
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_set_stop_flag_no_user_check(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id="t1")
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_is_stopped_reads_cache(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
mock_redis.get.return_value = b"1"
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
assert manager._is_stopped() is True
|
||||
|
||||
def test_check_for_sqlalchemy_models_raises(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
bad = SimpleNamespace(_sa_instance_state=True)
|
||||
with pytest.raises(TypeError):
|
||||
manager._check_for_sqlalchemy_models(bad)
|
||||
@@ -1,442 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class _DummyParameterRule:
|
||||
def __init__(self, name: str, use_template: str | None = None) -> None:
|
||||
self.name = name
|
||||
self.use_template = use_template
|
||||
|
||||
|
||||
class _QueueRecorder:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[object] = []
|
||||
|
||||
def publish(self, event, pub_from):
|
||||
_ = pub_from
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
class TestAppRunner:
|
||||
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[_DummyParameterRule("max_tokens")],
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="mock",
|
||||
model_schema=model_schema,
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ModelInstance",
|
||||
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80),
|
||||
)
|
||||
|
||||
runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")])
|
||||
|
||||
assert model_config.parameters["max_tokens"] == 20
|
||||
|
||||
def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
model_schema = SimpleNamespace(
|
||||
model_properties={},
|
||||
parameter_rules=[_DummyParameterRule("max_tokens")],
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="mock",
|
||||
model_schema=model_schema,
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ModelInstance",
|
||||
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10),
|
||||
)
|
||||
|
||||
assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1
|
||||
|
||||
def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True)
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None)
|
||||
|
||||
runner.direct_output(
|
||||
queue_manager=queue,
|
||||
app_generate_entity=app_generate_entity,
|
||||
prompt_messages=[],
|
||||
text="hi",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
|
||||
def test_handle_invoke_result_direct_publishes_end_event(self):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
llm_result = LLMResult(
|
||||
model="mock",
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content="done"),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=llm_result,
|
||||
queue_manager=queue,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
|
||||
def test_handle_invoke_result_invalid_type_raises(self):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=["unexpected"],
|
||||
queue_manager=queue,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def test_organize_prompt_messages_simple_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="chat", stop=["STOP"])
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="hello",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.SimplePromptTransform.get_prompt",
|
||||
lambda self, **kwargs: (["simple-message"], ["simple-stop"]),
|
||||
)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["simple-message"]
|
||||
assert stop == ["simple-stop"]
|
||||
|
||||
def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="completion", stop=["<END>"])
|
||||
captured: dict[str, object] = {}
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="answer",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"),
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_advanced_prompt(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["advanced-completion-message"]
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["advanced-completion-message"]
|
||||
assert stop == ["<END>"]
|
||||
memory_config = captured["memory_config"]
|
||||
assert memory_config.role_prefix.user == "U"
|
||||
assert memory_config.role_prefix.assistant == "A"
|
||||
|
||||
def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="chat", stop=["<END>"])
|
||||
captured: dict[str, object] = {}
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_advanced_prompt(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["advanced-chat-message"]
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["advanced-chat-message"]
|
||||
assert stop == ["<END>"]
|
||||
assert len(captured["prompt_template"]) == 2
|
||||
|
||||
def test_organize_prompt_messages_advanced_missing_templates_raise(self):
|
||||
runner = AppRunner()
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"):
|
||||
runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=SimpleNamespace(mode="completion", stop=[]),
|
||||
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"):
|
||||
runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=SimpleNamespace(mode="chat", stop=[]),
|
||||
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
warning_logger = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger)
|
||||
|
||||
image_content = ImagePromptMessageContent(
|
||||
url="https://example.com/image.png", format="png", mime_type="image/png"
|
||||
)
|
||||
|
||||
def _stream():
|
||||
yield LLMResultChunk(
|
||||
model="stream-model",
|
||||
prompt_messages=[AssistantPromptMessage(content="prompt")],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage.model_construct(
|
||||
content=[
|
||||
"a",
|
||||
TextPromptMessageContent(data="b"),
|
||||
SimpleNamespace(data="c"),
|
||||
image_content,
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=_stream(),
|
||||
queue_manager=queue,
|
||||
stream=True,
|
||||
agent=False,
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[0], QueueLLMChunkEvent)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
assert queue.events[-1].llm_result.message.content == "abc"
|
||||
warning_logger.assert_called_once()
|
||||
|
||||
def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
exception_logger = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger)
|
||||
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_handle_multimodal_image_content",
|
||||
MagicMock(side_effect=RuntimeError("failed to save image")),
|
||||
)
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
def _stream():
|
||||
yield LLMResultChunk(
|
||||
model="agent-model",
|
||||
prompt_messages=[AssistantPromptMessage(content="prompt")],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image.png",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
TextPromptMessageContent(data="done"),
|
||||
]
|
||||
),
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result_stream(
|
||||
invoke_result=_stream(),
|
||||
queue_manager=queue,
|
||||
agent=True,
|
||||
message_id="message-id",
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[0], QueueAgentMessageEvent)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
assert queue.events[-1].llm_result.usage == usage
|
||||
exception_logger.assert_called_once()
|
||||
|
||||
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
class _ToggleBool:
|
||||
def __init__(self, values: list[bool]):
|
||||
self._values = values
|
||||
self._index = 0
|
||||
|
||||
def __bool__(self):
|
||||
value = self._values[min(self._index, len(self._values) - 1)]
|
||||
self._index += 1
|
||||
return value
|
||||
|
||||
content = SimpleNamespace(
|
||||
url=_ToggleBool([False, False]),
|
||||
base64_data=_ToggleBool([True, False]),
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock())
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock())
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session))
|
||||
|
||||
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock())
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id="message-id",
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_check_hosting_moderation_direct_output_called(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
app_generate_entity = SimpleNamespace(stream=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.HostingModerationFeature.check",
|
||||
lambda self, application_generate_entity, prompt_messages: True,
|
||||
)
|
||||
direct_output = MagicMock()
|
||||
monkeypatch.setattr(runner, "direct_output", direct_output)
|
||||
|
||||
result = runner.check_hosting_moderation(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=queue,
|
||||
prompt_messages=[],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert direct_output.called
|
||||
|
||||
def test_fill_in_inputs_from_external_data_tools(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ExternalDataFetch.fetch",
|
||||
lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"},
|
||||
)
|
||||
|
||||
result = runner.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
external_data_tools=[],
|
||||
inputs={},
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert result == {"foo": "bar"}
|
||||
|
||||
def test_moderation_for_inputs_returns_result(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.InputModeration.check",
|
||||
lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""),
|
||||
)
|
||||
app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None)
|
||||
|
||||
result = runner.moderation_for_inputs(
|
||||
app_id="app",
|
||||
tenant_id="tenant",
|
||||
app_generate_entity=app_generate_entity,
|
||||
inputs={},
|
||||
query="q",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert result == (True, {}, "")
|
||||
|
||||
def test_query_app_annotations_to_reply(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.AnnotationReplyFeature.query",
|
||||
lambda self, app_record, message, query, user_id, invoke_from: "reply",
|
||||
)
|
||||
|
||||
response = runner.query_app_annotations_to_reply(
|
||||
app_record=SimpleNamespace(),
|
||||
message=SimpleNamespace(),
|
||||
query="hello",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
assert response == "reply"
|
||||
@@ -1,7 +0,0 @@
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
|
||||
|
||||
class TestAppsExceptions:
|
||||
def test_generate_task_stopped_error(self):
|
||||
err = GenerateTaskStoppedError("stopped")
|
||||
assert str(err) == "stopped"
|
||||
@@ -13,11 +13,9 @@ from core.app.app_config.entities import (
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from models.model import AppMode, Conversation, Message
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
|
||||
|
||||
class DummyModelConf:
|
||||
@@ -127,55 +125,3 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity():
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
|
||||
|
||||
class TestMessageBasedAppGeneratorExtras:
|
||||
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
class _Pipeline:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
_ = kwargs
|
||||
|
||||
def process(self):
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline",
|
||||
_Pipeline,
|
||||
)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)),
|
||||
queue_manager=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
user=SimpleNamespace(),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def test_get_app_model_config_requires_valid_config(self, monkeypatch):
|
||||
generator = MessageBasedAppGenerator()
|
||||
app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None)
|
||||
|
||||
with pytest.raises(AppModelConfigBrokenError):
|
||||
generator._get_app_model_config(app_model, conversation=None)
|
||||
|
||||
conversation = SimpleNamespace(app_model_config_id="missing-id")
|
||||
monkeypatch.setattr(
|
||||
message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None))
|
||||
)
|
||||
|
||||
with pytest.raises(AppModelConfigBrokenError):
|
||||
generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation)
|
||||
|
||||
def test_get_conversation_introduction_handles_missing_inputs(self):
|
||||
app_config = _make_app_config(AppMode.CHAT)
|
||||
app_config.additional_features.opening_statement = "Hello {{name}}"
|
||||
entity = _make_chat_generate_entity(app_config)
|
||||
entity.inputs = {}
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
assert generator._get_conversation_introduction(entity) == "Hello {name}"
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent
|
||||
|
||||
|
||||
class TestMessageBasedAppQueueManager:
|
||||
def test_publish_stops_on_terminal_events(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager.stop_listen = Mock()
|
||||
manager._is_stopped = Mock(return_value=False)
|
||||
|
||||
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock())
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
def test_publish_raises_when_stopped(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager._is_stopped = Mock(return_value=True)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def test_publish_enqueues_message_end(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager._is_stopped = Mock(return_value=False)
|
||||
manager.stop_listen = Mock()
|
||||
|
||||
manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
assert manager._q.qsize() == 1
|
||||
@@ -1,29 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestMessageGenerator:
|
||||
def test_get_response_topic(self):
|
||||
channel = Mock()
|
||||
channel.topic.return_value = "topic"
|
||||
|
||||
with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel):
|
||||
topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1")
|
||||
|
||||
assert topic == "topic"
|
||||
expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1")
|
||||
channel.topic.assert_called_once_with(expected_key)
|
||||
|
||||
def test_retrieve_events_passes_arguments(self):
|
||||
with (
|
||||
patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"),
|
||||
patch(
|
||||
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
|
||||
) as mock_stream,
|
||||
):
|
||||
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
|
||||
|
||||
assert events == [{"event": "ping"}]
|
||||
mock_stream.assert_called_once()
|
||||
@@ -6,7 +6,6 @@ import queue
|
||||
import pytest
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
@@ -79,30 +78,3 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
|
||||
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_normalize_terminal_events_defaults():
|
||||
assert _normalize_terminal_events(None) == {
|
||||
StreamEvent.WORKFLOW_FINISHED.value,
|
||||
StreamEvent.WORKFLOW_PAUSED.value,
|
||||
}
|
||||
|
||||
|
||||
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
|
||||
|
||||
def fake_time():
|
||||
return times.pop(0)
|
||||
|
||||
monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time)
|
||||
|
||||
generator = stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=10.0,
|
||||
ping_interval=1.0,
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
# next receive yields None -> ping interval triggers
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
|
||||
@@ -1,261 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
|
||||
|
||||
class TestWorkflowBasedAppRunner:
|
||||
def test_resolve_user_from(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT
|
||||
assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT
|
||||
assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER
|
||||
|
||||
def test_init_graph_validates_graph_structure(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="nodes or edges not found"):
|
||||
runner._init_graph(
|
||||
graph_config={},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="nodes in workflow graph must be a list"):
|
||||
runner._init_graph(
|
||||
graph_config={"nodes": {}, "edges": []},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="edges in workflow graph must be a list"):
|
||||
runner._init_graph(
|
||||
graph_config={"nodes": [], "edges": {}},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
def test_prepare_single_node_execution_requires_run(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
workflow = SimpleNamespace(environment_variables=[], graph_dict={})
|
||||
|
||||
with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"):
|
||||
runner._prepare_single_node_execution(workflow, None, None)
|
||||
|
||||
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
graph_config = {
|
||||
"nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}],
|
||||
"edges": [],
|
||||
}
|
||||
workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.Graph.init",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
|
||||
class _NodeCls:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(graph_config, config):
|
||||
return {}
|
||||
|
||||
from core.app.apps import workflow_app_runner
|
||||
|
||||
monkeypatch.setitem(
|
||||
workflow_app_runner.NODE_TYPE_CLASSES_MAPPING,
|
||||
NodeType.START,
|
||||
{"1": _NodeCls},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.load_into_variable_pool",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id="node-1",
|
||||
user_inputs={},
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
assert graph is not None
|
||||
assert variable_pool is graph_runtime_state.variable_pool
|
||||
|
||||
def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
|
||||
published: list[object] = []
|
||||
|
||||
class _QueueManager:
|
||||
def publish(self, event, publish_from):
|
||||
published.append((event, publish_from))
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_runtime_state.register_paused_node("node-1")
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
emails: list[dict] = []
|
||||
|
||||
class _Dispatch:
|
||||
def apply_async(self, *, kwargs, queue):
|
||||
emails.append({"kwargs": kwargs, "queue": queue})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.dispatch_human_input_email_task",
|
||||
_Dispatch(),
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Node",
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry, GraphRunStartedEvent())
|
||||
runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True}))
|
||||
runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={}))
|
||||
|
||||
assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published)
|
||||
assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published)
|
||||
paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent))
|
||||
assert paused_event.paused_nodes == ["node-1"]
|
||||
assert emails
|
||||
|
||||
def test_handle_node_events_publishes_queue_events(self):
|
||||
published: list[object] = []
|
||||
|
||||
class _QueueManager:
|
||||
def publish(self, event, publish_from):
|
||||
published.append(event)
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunStartedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
node_title="Start",
|
||||
start_at=datetime.utcnow(),
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunStreamChunkEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
selector=["node", "text"],
|
||||
chunk="hi",
|
||||
is_final=False,
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunAgentLogEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
message_id="msg",
|
||||
label="label",
|
||||
node_execution_id="exec",
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status="done",
|
||||
data={},
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunIterationSucceededEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Iter",
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={"ok": True},
|
||||
metadata={},
|
||||
steps=1,
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunLoopFailedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Loop",
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
metadata={},
|
||||
steps=1,
|
||||
error="boom",
|
||||
),
|
||||
)
|
||||
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in published)
|
||||
assert any(isinstance(event, QueueAgentLogEvent) for event in published)
|
||||
assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
|
||||
assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)
|
||||
@@ -1,61 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestWorkflowAppConfigManager:
|
||||
def test_get_app_config(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
workflow = SimpleNamespace(id="wf-1", features_dict={})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model, workflow)
|
||||
|
||||
assert app_config.workflow_id == "wf-1"
|
||||
assert app_config.app_mode == AppMode.WORKFLOW
|
||||
|
||||
def test_config_validate_filters_keys(self):
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
# Support both positional and keyword arguments for config
|
||||
if "config" in kwargs:
|
||||
config = kwargs["config"]
|
||||
elif len(args) > 0:
|
||||
config = args[0]
|
||||
else:
|
||||
config = {}
|
||||
config[key] = value
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 3),
|
||||
),
|
||||
):
|
||||
filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={})
|
||||
|
||||
assert filtered["file_upload"] == 1
|
||||
assert filtered["text_to_speech"] == 2
|
||||
assert filtered["sensitive_word_avoidance"] == 3
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestWorkflowAppGeneratorValidation:
|
||||
def test_should_prepare_user_inputs(self):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
assert generator._should_prepare_user_inputs({}) is True
|
||||
assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False
|
||||
|
||||
def test_single_iteration_generate_validates_args(self):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
with pytest.raises(ValueError, match="node_id is required"):
|
||||
generator.single_iteration_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="",
|
||||
user=SimpleNamespace(),
|
||||
args={"inputs": {}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="inputs is required"):
|
||||
generator.single_iteration_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="node",
|
||||
user=SimpleNamespace(),
|
||||
args={},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_single_loop_generate_validates_args(self):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
with pytest.raises(ValueError, match="node_id is required"):
|
||||
generator.single_loop_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="",
|
||||
user=SimpleNamespace(),
|
||||
args=SimpleNamespace(inputs={}),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="inputs is required"):
|
||||
generator.single_loop_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="node",
|
||||
user=SimpleNamespace(),
|
||||
args=SimpleNamespace(inputs=None),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowAppGeneratorHandleResponse:
|
||||
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
class _Pipeline:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
_ = kwargs
|
||||
|
||||
def process(self):
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline",
|
||||
_Pipeline,
|
||||
)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(),
|
||||
queue_manager=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowAppGeneratorGenerate:
|
||||
def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config",
|
||||
lambda app_model, workflow: app_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.FileUploadConfigManager.convert",
|
||||
lambda features_dict, is_vision=False: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.file_factory.build_from_mappings",
|
||||
lambda **kwargs: [],
|
||||
)
|
||||
DummyTraceQueueManager = type(
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.TraceQueueManager",
|
||||
DummyTraceQueueManager,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.sessionmaker",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
|
||||
prepare_inputs = pytest.fail
|
||||
monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs())
|
||||
|
||||
monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True})
|
||||
|
||||
result = generator.generate(
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
|
||||
workflow=SimpleNamespace(features_dict={}),
|
||||
user=SimpleNamespace(id="user", session_id="session"),
|
||||
args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
@@ -1,33 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent
|
||||
|
||||
|
||||
class TestWorkflowAppQueueManager:
|
||||
def test_publish_stop_events_trigger_stop(self):
|
||||
manager = WorkflowAppQueueManager(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
app_mode="workflow",
|
||||
)
|
||||
manager._is_stopped = lambda: True
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def test_publish_non_stop_event_does_not_raise(self):
|
||||
manager = WorkflowAppQueueManager(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
app_mode="workflow",
|
||||
)
|
||||
|
||||
manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||
@@ -1,9 +0,0 @@
|
||||
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
|
||||
|
||||
|
||||
class TestWorkflowErrors:
|
||||
def test_workflow_paused_in_blocking_mode_error_attributes(self):
|
||||
err = WorkflowPausedInBlockingModeError()
|
||||
assert err.error_code == "workflow_paused_in_blocking_mode"
|
||||
assert err.code == 400
|
||||
assert "blocking response mode" in err.description
|
||||
@@ -1,133 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestWorkflowGenerateResponseConverter:
|
||||
def test_blocking_full_response(self):
|
||||
blocking = WorkflowAppBlockingResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="exec-1",
|
||||
workflow_id="wf-1",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"ok": True},
|
||||
error=None,
|
||||
elapsed_time=1.2,
|
||||
total_tokens=10,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
assert response["workflow_run_id"] == "r1"
|
||||
|
||||
def test_stream_simple_response_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
elapsed_time=0.1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1"))
|
||||
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start)
|
||||
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish)
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom"))
|
||||
)
|
||||
|
||||
converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert converted[0] == "ping"
|
||||
assert converted[1]["event"] == "node_started"
|
||||
assert converted[2]["event"] == "node_finished"
|
||||
assert converted[3]["event"] == "error"
|
||||
|
||||
def test_convert_stream_simple_response_handles_ping_and_nodes(self):
|
||||
def _gen():
|
||||
yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task"))
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run",
|
||||
stream_response=NodeStartStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="node-exec",
|
||||
node_id="node",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run",
|
||||
stream_response=NodeFinishStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="node-exec",
|
||||
node_id="node",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
elapsed_time=1.0,
|
||||
error=None,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen()))
|
||||
|
||||
assert chunks[0] == "ping"
|
||||
assert chunks[1]["event"] == "node_started"
|
||||
assert chunks[2]["event"] == "node_finished"
|
||||
|
||||
def test_convert_stream_full_response_handles_error(self):
|
||||
def _gen():
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run",
|
||||
stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")),
|
||||
)
|
||||
|
||||
chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen()))
|
||||
|
||||
assert chunks[0]["event"] == "error"
|
||||
@@ -1,868 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from dify_graph.enums import NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _make_pipeline():
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras={},
|
||||
call_depth=0,
|
||||
)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = SimpleNamespace(id="user", session_id="session")
|
||||
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
user=user,
|
||||
stream=False,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_handles_pause(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _gen():
|
||||
yield WorkflowPauseStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
def test_to_blocking_response_handles_finish(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _gen():
|
||||
yield WorkflowFinishStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id="run",
|
||||
workflow_id="workflow-id",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"ok": True},
|
||||
error=None,
|
||||
elapsed_time=1.0,
|
||||
total_tokens=5,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.outputs == {"ok": True}
|
||||
|
||||
def test_listen_audio_msg_returns_audio_stream(self):
|
||||
pipeline = _make_pipeline()
|
||||
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
|
||||
|
||||
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
|
||||
|
||||
assert isinstance(response, MessageAudioStreamResponse)
|
||||
|
||||
def test_handle_ping_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
|
||||
|
||||
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
|
||||
|
||||
assert isinstance(responses[0], PingStreamResponse)
|
||||
|
||||
def test_handle_error_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
|
||||
|
||||
assert isinstance(responses[0], ValueError)
|
||||
|
||||
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None)
|
||||
|
||||
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
|
||||
|
||||
assert pipeline._workflow_execution_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
def test_handle_node_succeeded_event_saves_output(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
assert responses == ["done"]
|
||||
|
||||
def test_handle_workflow_failed_event_yields_error(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_workflow_failed_and_stop_events(QueueWorkflowFailedEvent(error="fail", exceptions_count=1))
|
||||
)
|
||||
|
||||
assert responses[0] == "finish"
|
||||
|
||||
def test_handle_text_chunk_event_publishes_tts(self):
|
||||
pipeline = _make_pipeline()
|
||||
published: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
published.append(message)
|
||||
|
||||
event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"])
|
||||
queue_message = SimpleNamespace(event=event)
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
|
||||
)
|
||||
|
||||
assert responses[0].data.text == "hi"
|
||||
assert published == [queue_message]
|
||||
|
||||
def test_dispatch_event_handles_node_failed(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
|
||||
event = QueueNodeFailedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._dispatch_event(event)) == ["done"]
|
||||
|
||||
def test_handle_stop_event_yields_finish(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_workflow_failed_and_stop_events(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
|
||||
)
|
||||
)
|
||||
|
||||
assert responses == ["finish"]
|
||||
|
||||
def test_save_workflow_app_log_created_from(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
pipeline._user_id = "user"
|
||||
added: list[object] = []
|
||||
|
||||
class _Session:
|
||||
def add(self, item):
|
||||
added.append(item)
|
||||
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
|
||||
assert added
|
||||
|
||||
def test_iteration_loop_and_human_input_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter"
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
|
||||
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
|
||||
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
|
||||
pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log"
|
||||
|
||||
iter_start = QueueIterationStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_next = QueueIterationNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_done = QueueIterationCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_start = QueueLoopStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_next = QueueLoopNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_done = QueueLoopCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
filled_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
action_id="action",
|
||||
action_text="action",
|
||||
)
|
||||
timeout_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
expiration_time=datetime.utcnow(),
|
||||
)
|
||||
agent_event = QueueAgentLogEvent(
|
||||
id="log",
|
||||
label="label",
|
||||
node_execution_id="exec",
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status="done",
|
||||
data={},
|
||||
metadata={},
|
||||
node_id="node",
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"]
|
||||
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"]
|
||||
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"]
|
||||
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"]
|
||||
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
|
||||
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
|
||||
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
|
||||
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
|
||||
assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"]
|
||||
|
||||
def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {
|
||||
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
|
||||
}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
|
||||
|
||||
class _Publisher:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.calls = 0
|
||||
|
||||
def check_and_get_audio(self):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return AudioTrunk(status="stream", audio="data")
|
||||
if self.calls == 2:
|
||||
return None
|
||||
return AudioTrunk(status="finish", audio="")
|
||||
|
||||
def publish(self, message):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
|
||||
_Publisher,
|
||||
)
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
|
||||
assert any(isinstance(item, MessageAudioStreamResponse) for item in responses)
|
||||
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
|
||||
|
||||
def test_init_with_end_user_sets_role_and_system_user(self):
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="end-user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras={},
|
||||
call_depth=0,
|
||||
)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None)
|
||||
end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id")
|
||||
end_user.id = "end-user-id"
|
||||
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=end_user,
|
||||
stream=False,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
assert pipeline._created_by_role == CreatorUserRole.END_USER
|
||||
assert pipeline._workflow_system_variables.user_id == "session-id"
|
||||
|
||||
def test_process_returns_stream_and_blocking_variants(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.stream = True
|
||||
pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
|
||||
|
||||
stream_response = list(pipeline.process())
|
||||
assert len(stream_response) == 1
|
||||
assert stream_response[0].workflow_run_id is None
|
||||
|
||||
pipeline._base_task_pipeline.stream = False
|
||||
pipeline._wrapper_process_stream_response = lambda **kwargs: iter(
|
||||
[
|
||||
WorkflowFinishStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
blocking_response = pipeline.process()
|
||||
assert blocking_response.workflow_run_id == "run-id"
|
||||
|
||||
def test_to_blocking_response_handles_error_and_unexpected_end(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _error_gen():
|
||||
yield ErrorStreamResponse(task_id="task", err=ValueError("boom"))
|
||||
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
pipeline._to_blocking_response(_error_gen())
|
||||
|
||||
def _unexpected_gen():
|
||||
yield PingStreamResponse(task_id="task")
|
||||
|
||||
with pytest.raises(ValueError, match="queue listening stopped unexpectedly"):
|
||||
pipeline._to_blocking_response(_unexpected_gen())
|
||||
|
||||
def test_to_stream_response_tracks_workflow_run_id(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _gen():
|
||||
yield WorkflowStartStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
inputs={},
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
yield PingStreamResponse(task_id="task")
|
||||
|
||||
stream_responses = list(pipeline._to_stream_response(_gen()))
|
||||
assert stream_responses[0].workflow_run_id == "run-id"
|
||||
assert stream_responses[1].workflow_run_id == "run-id"
|
||||
|
||||
def test_listen_audio_msg_returns_none_without_publisher(self):
|
||||
pipeline = _make_pipeline()
|
||||
assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None
|
||||
|
||||
def test_wrapper_process_stream_response_without_tts(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
assert responses == [PingStreamResponse(task_id="task")]
|
||||
|
||||
def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {
|
||||
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
|
||||
}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([])
|
||||
|
||||
sleep_spy = []
|
||||
|
||||
class _Publisher:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.calls = 0
|
||||
|
||||
def check_and_get_audio(self):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return None
|
||||
return AudioTrunk(status="finish", audio="")
|
||||
|
||||
def publish(self, message):
|
||||
_ = message
|
||||
|
||||
time_values = iter([0.0, 0.0, 0.2])
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values))
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
|
||||
_Publisher,
|
||||
)
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
|
||||
assert sleep_spy
|
||||
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
|
||||
|
||||
def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {
|
||||
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
|
||||
}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([])
|
||||
|
||||
class _Publisher:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.called = False
|
||||
|
||||
def check_and_get_audio(self):
|
||||
if not self.called:
|
||||
self.called = True
|
||||
raise RuntimeError("tts failure")
|
||||
return AudioTrunk(status="finish", audio="")
|
||||
|
||||
def publish(self, message):
|
||||
_ = message
|
||||
|
||||
logger_exception = []
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.logger.exception",
|
||||
lambda *args, **kwargs: logger_exception.append((args, kwargs)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
|
||||
_Publisher,
|
||||
)
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
|
||||
assert logger_exception
|
||||
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
|
||||
|
||||
def test_database_session_rolls_back_on_error(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"commit": 0, "rollback": 0}
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def commit(self):
|
||||
calls["commit"] += 1
|
||||
|
||||
def rollback(self):
|
||||
calls["rollback"] += 1
|
||||
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
|
||||
|
||||
with pytest.raises(RuntimeError, match="db error"):
|
||||
with pipeline._database_session():
|
||||
raise RuntimeError("db error")
|
||||
|
||||
assert calls["commit"] == 0
|
||||
assert calls["rollback"] == 1
|
||||
|
||||
def test_node_retry_and_started_handlers_cover_none_and_value(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
|
||||
retry_event = QueueNodeRetryEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_title="title",
|
||||
node_type=NodeType.LLM,
|
||||
node_run_index=1,
|
||||
start_at=datetime.utcnow(),
|
||||
provider_type="provider",
|
||||
provider_id="provider-id",
|
||||
error="error",
|
||||
retry_index=1,
|
||||
)
|
||||
started_event = QueueNodeStartedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_title="title",
|
||||
node_type=NodeType.LLM,
|
||||
node_run_index=1,
|
||||
start_at=datetime.utcnow(),
|
||||
provider_type="provider",
|
||||
provider_id="provider-id",
|
||||
)
|
||||
|
||||
pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None
|
||||
assert list(pipeline._handle_node_retry_event(retry_event)) == []
|
||||
pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry"
|
||||
assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"]
|
||||
|
||||
pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None
|
||||
assert list(pipeline._handle_node_started_event(started_event)) == []
|
||||
pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: "started"
|
||||
assert list(pipeline._handle_node_started_event(started_event)) == ["started"]
|
||||
|
||||
def test_handle_node_exception_event_saves_output(self):
|
||||
pipeline = _make_pipeline()
|
||||
saved_ids: list[str] = []
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: saved_ids.append(node_execution_id)
|
||||
|
||||
event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec-id",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="boom",
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_failed_events(event))
|
||||
assert responses == ["failed"]
|
||||
assert saved_ids == ["exec-id"]
|
||||
|
||||
def test_success_partial_and_pause_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"]
|
||||
assert list(
|
||||
pipeline._handle_workflow_partial_success_event(
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={})
|
||||
)
|
||||
) == ["finish"]
|
||||
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [
|
||||
"pause-a",
|
||||
"pause-b",
|
||||
]
|
||||
pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"])
|
||||
assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"]
|
||||
|
||||
def test_text_chunk_handler_returns_empty_when_text_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
event = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None)
|
||||
assert list(pipeline._handle_text_chunk_event(event)) == []
|
||||
|
||||
def test_dispatch_event_direct_failed_and_unhandled_paths(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
|
||||
assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"]
|
||||
|
||||
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"])
|
||||
assert list(pipeline._dispatch_event(QueueWorkflowFailedEvent(error="failed", exceptions_count=1))) == [
|
||||
"workflow-failed"
|
||||
]
|
||||
|
||||
assert list(pipeline._dispatch_event(SimpleNamespace())) == []
|
||||
|
||||
def test_process_stream_response_main_match_paths_and_cleanup(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[
|
||||
SimpleNamespace(event=QueueWorkflowStartedEvent()),
|
||||
SimpleNamespace(event=QueueTextChunkEvent(text="hello")),
|
||||
SimpleNamespace(event=QueuePingEvent()),
|
||||
SimpleNamespace(event=QueueErrorEvent(error="e")),
|
||||
]
|
||||
)
|
||||
pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"])
|
||||
pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"])
|
||||
pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"])
|
||||
pipeline._handle_error_event = lambda event, **kwargs: iter(["error"])
|
||||
publisher_calls: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
publisher_calls.append(message)
|
||||
|
||||
responses = list(pipeline._process_stream_response(tts_publisher=_Publisher()))
|
||||
assert responses == ["started", "text", "dispatched", "error"]
|
||||
assert publisher_calls == [None]
|
||||
|
||||
def test_process_stream_response_break_paths(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))]
|
||||
)
|
||||
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"])
|
||||
assert list(pipeline._process_stream_response()) == ["failed"]
|
||||
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))]
|
||||
)
|
||||
pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"])
|
||||
assert list(pipeline._process_stream_response()) == ["paused"]
|
||||
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))]
|
||||
)
|
||||
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"])
|
||||
assert list(pipeline._process_stream_response()) == ["stopped"]
|
||||
|
||||
def test_save_workflow_app_log_covers_invoke_from_variants(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._user_id = "user-id"
|
||||
added: list[object] = []
|
||||
|
||||
class _Session:
|
||||
def add(self, item):
|
||||
added.append(item)
|
||||
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
assert added[-1].created_from == "installed-app"
|
||||
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
assert added[-1].created_from == "web-app"
|
||||
|
||||
count_before = len(added)
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
assert len(added) == count_before
|
||||
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None)
|
||||
assert len(added) == count_before
|
||||
|
||||
def test_save_output_for_event_writes_draft_variables(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
saver_calls: list[tuple[object, object]] = []
|
||||
captured_factory_args: dict[str, object] = {}
|
||||
|
||||
class _Saver:
|
||||
def save(self, process_data, outputs):
|
||||
saver_calls.append((process_data, outputs))
|
||||
|
||||
def _factory(**kwargs):
|
||||
captured_factory_args.update(kwargs)
|
||||
return _Saver()
|
||||
|
||||
class _Begin:
|
||||
def __enter__(self):
|
||||
return None
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return _Begin()
|
||||
|
||||
pipeline._draft_var_saver_factory = _factory
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="exec-id",
|
||||
node_id="node-id",
|
||||
node_type=NodeType.START,
|
||||
in_loop_id="loop-id",
|
||||
start_at=datetime.utcnow(),
|
||||
process_data={"k": "v"},
|
||||
outputs={"out": 1},
|
||||
)
|
||||
pipeline._save_output_for_event(event=event, node_execution_id="exec-id")
|
||||
|
||||
assert captured_factory_args["node_execution_id"] == "exec-id"
|
||||
assert captured_factory_args["enclosing_node_id"] == "loop-id"
|
||||
assert saver_calls == [({"k": "v"}, {"out": 1})]
|
||||
@@ -1,390 +0,0 @@
|
||||
import base64
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.base.tts.app_generator_tts_publisher import (
|
||||
AppGeneratorTTSPublisher,
|
||||
AudioTrunk,
|
||||
_invoice_tts,
|
||||
_process_future,
|
||||
)
|
||||
|
||||
# =========================
|
||||
# Fixtures
|
||||
# =========================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance(mocker):
|
||||
model = mocker.MagicMock()
|
||||
model.invoke_tts.return_value = [b"audio1", b"audio2"]
|
||||
model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}]
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_manager(mocker, mock_model_instance):
|
||||
manager = mocker.MagicMock()
|
||||
manager.get_default_model_instance.return_value = mock_model_instance
|
||||
mocker.patch(
|
||||
"core.base.tts.app_generator_tts_publisher.ModelManager",
|
||||
return_value=manager,
|
||||
)
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_threads(mocker):
|
||||
"""Prevent real threads from starting during tests"""
|
||||
mocker.patch("threading.Thread.start", return_value=None)
|
||||
|
||||
|
||||
# =========================
|
||||
# AudioTrunk Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestAudioTrunk:
|
||||
def test_audio_trunk_initialization(self):
|
||||
trunk = AudioTrunk("responding", b"data")
|
||||
assert trunk.status == "responding"
|
||||
assert trunk.audio == b"data"
|
||||
|
||||
|
||||
# =========================
|
||||
# _invoice_tts Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestInvoiceTTS:
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[None, "", " "],
|
||||
)
|
||||
def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
|
||||
result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
|
||||
assert result is None
|
||||
mock_model_instance.invoke_tts.assert_not_called()
|
||||
|
||||
def test_invoice_tts_valid_text(self, mock_model_instance):
|
||||
result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
|
||||
mock_model_instance.invoke_tts.assert_called_once_with(
|
||||
content_text="hello",
|
||||
user="responding_tts",
|
||||
tenant_id="tenant",
|
||||
voice="voice1",
|
||||
)
|
||||
assert result == [b"audio1", b"audio2"]
|
||||
|
||||
|
||||
# =========================
|
||||
# _process_future Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestProcessFuture:
|
||||
def test_process_future_normal_flow(self):
|
||||
future_queue = queue.Queue()
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
future = MagicMock()
|
||||
future.result.return_value = [b"abc"]
|
||||
|
||||
future_queue.put(future)
|
||||
future_queue.put(None)
|
||||
|
||||
_process_future(future_queue, audio_queue)
|
||||
|
||||
first = audio_queue.get()
|
||||
assert first.status == "responding"
|
||||
assert first.audio == base64.b64encode(b"abc")
|
||||
|
||||
finish = audio_queue.get()
|
||||
assert finish.status == "finish"
|
||||
|
||||
def test_process_future_empty_result(self):
|
||||
future_queue = queue.Queue()
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
future = MagicMock()
|
||||
future.result.return_value = None
|
||||
|
||||
future_queue.put(future)
|
||||
future_queue.put(None)
|
||||
|
||||
_process_future(future_queue, audio_queue)
|
||||
|
||||
finish = audio_queue.get()
|
||||
assert finish.status == "finish"
|
||||
|
||||
def test_process_future_exception(self, mocker):
|
||||
future_queue = queue.Queue()
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
future = MagicMock()
|
||||
future.result.side_effect = Exception("error")
|
||||
|
||||
future_queue.put(future)
|
||||
|
||||
_process_future(future_queue, audio_queue)
|
||||
|
||||
finish = audio_queue.get()
|
||||
assert finish.status == "finish"
|
||||
|
||||
|
||||
# =========================
|
||||
# AppGeneratorTTSPublisher Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestAppGeneratorTTSPublisher:
|
||||
def test_initialization_valid_voice(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
assert publisher.voice == "voice1"
|
||||
assert publisher.max_sentence == 2
|
||||
assert publisher.msg_text == ""
|
||||
|
||||
def test_initialization_invalid_voice_fallback(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice")
|
||||
assert publisher.voice == "voice1"
|
||||
|
||||
def test_publish_puts_message_in_queue(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
message = MagicMock()
|
||||
publisher.publish(message)
|
||||
assert publisher._msg_queue.get() == message
|
||||
|
||||
def test_check_and_get_audio_no_audio(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
result = publisher.check_and_get_audio()
|
||||
assert result is None
|
||||
|
||||
def test_check_and_get_audio_non_finish_event(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
trunk = AudioTrunk("responding", b"abc")
|
||||
publisher._audio_queue.put(trunk)
|
||||
|
||||
result = publisher.check_and_get_audio()
|
||||
|
||||
assert result.status == "responding"
|
||||
assert publisher._last_audio_event == trunk
|
||||
|
||||
def test_check_and_get_audio_finish_event(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
finish_trunk = AudioTrunk("finish", b"")
|
||||
publisher._audio_queue.put(finish_trunk)
|
||||
|
||||
result = publisher.check_and_get_audio()
|
||||
|
||||
assert result.status == "finish"
|
||||
publisher.executor.shutdown.assert_called_once()
|
||||
|
||||
def test_check_and_get_audio_cached_finish(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
publisher._last_audio_event = AudioTrunk("finish", b"")
|
||||
|
||||
result = publisher.check_and_get_audio()
|
||||
|
||||
assert result.status == "finish"
|
||||
publisher.executor.shutdown.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "expected_sentences", "expected_remaining"),
|
||||
[
|
||||
("Hello world.", ["Hello world."], ""),
|
||||
("Hello world! How are you?", ["Hello world!", " How are you?"], ""),
|
||||
("No punctuation", [], "No punctuation"),
|
||||
("", [], ""),
|
||||
],
|
||||
)
|
||||
def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
sentences, remaining = publisher._extract_sentence(text)
|
||||
assert sentences == expected_sentences
|
||||
assert remaining == expected_remaining
|
||||
|
||||
def test_runtime_handles_none_message_with_buffer(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
publisher.msg_text = "Hello."
|
||||
|
||||
publisher._msg_queue.put(None)
|
||||
publisher._runtime()
|
||||
|
||||
publisher.executor.submit.assert_called_once()
|
||||
|
||||
def test_runtime_handles_none_message_without_buffer(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
publisher.msg_text = " "
|
||||
|
||||
publisher._msg_queue.put(None)
|
||||
publisher._runtime()
|
||||
|
||||
publisher.executor.submit.assert_not_called()
|
||||
|
||||
def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
# Force sentence extraction to hit threshold condition
|
||||
mocker.patch.object(
|
||||
publisher,
|
||||
"_extract_sentence",
|
||||
return_value=(["Hello world.", " Second sentence."], ""),
|
||||
)
|
||||
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueTextChunkEvent)
|
||||
event.event.text = "Hello world. Second sentence."
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.executor.submit.called
|
||||
|
||||
def test_runtime_handles_text_chunk_event(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueTextChunkEvent)
|
||||
event.event.text = "Hello world."
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.executor.submit.called
|
||||
|
||||
def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueNodeSucceededEvent)
|
||||
event.event.outputs = {"output": "Hello world."}
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.executor.submit.called
|
||||
|
||||
def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueNodeSucceededEvent)
|
||||
event.event.outputs = None
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
publisher.executor.submit.assert_not_called()
|
||||
|
||||
def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
chunk = LLMResultChunk(
|
||||
model="model",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="Hello "),
|
||||
ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"),
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
|
||||
|
||||
mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.msg_text == "Hello "
|
||||
|
||||
def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
chunk = LLMResultChunk(
|
||||
model="model",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
),
|
||||
)
|
||||
event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
|
||||
|
||||
mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.msg_text == ""
|
||||
|
||||
def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueTextChunkEvent)
|
||||
event.event.text = "Hello world. Another sentence."
|
||||
|
||||
mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None))
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.msg_text == ""
|
||||
|
||||
def test_runtime_exception_path(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher._msg_queue = MagicMock()
|
||||
publisher._msg_queue.get.side_effect = Exception("error")
|
||||
|
||||
publisher._runtime()
|
||||
@@ -1,197 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.callback_handler.agent_tool_callback_handler as module
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_debug(mocker):
|
||||
mocker.patch.object(module.dify_config, "DEBUG", True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_debug(mocker):
|
||||
mocker.patch.object(module.dify_config, "DEBUG", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_print(mocker):
|
||||
return mocker.patch("builtins.print")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
return module.DifyAgentCallbackHandler(color="blue")
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# get_colored_text Tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class TestGetColoredText:
|
||||
@pytest.mark.parametrize(
|
||||
("color", "expected_code"),
|
||||
[
|
||||
("blue", "36;1"),
|
||||
("yellow", "33;1"),
|
||||
("pink", "38;5;200"),
|
||||
("green", "32;1"),
|
||||
("red", "31;1"),
|
||||
],
|
||||
)
|
||||
def test_get_colored_text_valid_colors(self, color, expected_code):
|
||||
text = "hello"
|
||||
result = module.get_colored_text(text, color)
|
||||
assert expected_code in result
|
||||
assert text in result
|
||||
assert result.endswith("\u001b[0m")
|
||||
|
||||
def test_get_colored_text_invalid_color_raises(self):
|
||||
with pytest.raises(KeyError):
|
||||
module.get_colored_text("hello", "invalid")
|
||||
|
||||
def test_get_colored_text_empty_string(self):
|
||||
result = module.get_colored_text("", "green")
|
||||
assert "\u001b[" in result
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# print_text Tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class TestPrintText:
|
||||
def test_print_text_without_color(self, mock_print):
|
||||
module.print_text("hello")
|
||||
mock_print.assert_called_once_with("hello", end="", file=None)
|
||||
|
||||
def test_print_text_with_color(self, mocker, mock_print):
|
||||
mock_get_color = mocker.patch(
|
||||
"core.callback_handler.agent_tool_callback_handler.get_colored_text",
|
||||
return_value="colored_text",
|
||||
)
|
||||
|
||||
module.print_text("hello", color="green")
|
||||
|
||||
mock_get_color.assert_called_once_with("hello", "green")
|
||||
mock_print.assert_called_once_with("colored_text", end="", file=None)
|
||||
|
||||
def test_print_text_with_file_flush(self, mocker):
|
||||
mock_file = MagicMock()
|
||||
mock_print = mocker.patch("builtins.print")
|
||||
|
||||
module.print_text("hello", file=mock_file)
|
||||
|
||||
mock_print.assert_called_once_with("hello", end="", file=mock_file)
|
||||
mock_file.flush.assert_called_once()
|
||||
|
||||
def test_print_text_with_end_parameter(self, mock_print):
|
||||
module.print_text("hello", end="\n")
|
||||
mock_print.assert_called_once_with("hello", end="\n", file=None)
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# DifyAgentCallbackHandler Tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class TestDifyAgentCallbackHandler:
|
||||
def test_init_default_color(self):
|
||||
handler = module.DifyAgentCallbackHandler()
|
||||
assert handler.color == "green"
|
||||
assert handler.current_loop == 1
|
||||
|
||||
def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_start("tool1", {"a": 1})
|
||||
|
||||
mock_print_text.assert_called()
|
||||
|
||||
def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_start("tool1", {"a": 1})
|
||||
|
||||
mock_print_text.assert_not_called()
|
||||
|
||||
def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
mock_trace_manager = MagicMock()
|
||||
|
||||
handler.on_tool_end(
|
||||
tool_name="tool1",
|
||||
tool_inputs={"a": 1},
|
||||
tool_outputs="output",
|
||||
message_id="msg1",
|
||||
timer=123,
|
||||
trace_manager=mock_trace_manager,
|
||||
)
|
||||
|
||||
assert mock_print_text.call_count >= 1
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_end(
|
||||
tool_name="tool1",
|
||||
tool_inputs={},
|
||||
tool_outputs="output",
|
||||
)
|
||||
|
||||
assert mock_print_text.call_count >= 1
|
||||
|
||||
def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_error(Exception("error"))
|
||||
|
||||
mock_print_text.assert_called_once()
|
||||
|
||||
def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_error(Exception("error"))
|
||||
|
||||
mock_print_text.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("thought", ["thinking", ""])
|
||||
def test_on_agent_start(self, handler, enable_debug, mocker, thought):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_agent_start(thought)
|
||||
|
||||
mock_print_text.assert_called()
|
||||
|
||||
def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
current_loop = handler.current_loop
|
||||
handler.on_agent_finish()
|
||||
|
||||
assert handler.current_loop == current_loop + 1
|
||||
mock_print_text.assert_called()
|
||||
|
||||
def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_datasource_start("ds1", {"x": 1})
|
||||
|
||||
mock_print_text.assert_called_once()
|
||||
|
||||
def test_ignore_agent_property(self, disable_debug, handler):
|
||||
assert handler.ignore_agent is True
|
||||
|
||||
def test_ignore_chat_model_property(self, disable_debug, handler):
|
||||
assert handler.ignore_chat_model is True
|
||||
|
||||
def test_ignore_properties_when_debug_enabled(self, enable_debug, handler):
|
||||
assert handler.ignore_agent is False
|
||||
assert handler.ignore_chat_model is False
|
||||
@@ -1,162 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import (
|
||||
DatasetIndexToolCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(mocker):
|
||||
return mocker.Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler(mock_queue_manager, mocker):
|
||||
mocker.patch(
|
||||
"core.callback_handler.index_tool_callback_handler.db",
|
||||
)
|
||||
return DatasetIndexToolCallbackHandler(
|
||||
queue_manager=mock_queue_manager,
|
||||
app_id="app-1",
|
||||
message_id="msg-1",
|
||||
user_id="user-1",
|
||||
invoke_from=mocker.Mock(),
|
||||
)
|
||||
|
||||
|
||||
class TestOnQuery:
|
||||
@pytest.mark.parametrize(
|
||||
("invoke_from", "expected_role"),
|
||||
[
|
||||
(InvokeFrom.EXPLORE, "account"),
|
||||
(InvokeFrom.DEBUGGER, "account"),
|
||||
(InvokeFrom.WEB_APP, "end_user"),
|
||||
],
|
||||
)
|
||||
def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role):
|
||||
# Arrange
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
handler = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=mock_queue_manager,
|
||||
app_id="app-1",
|
||||
message_id="msg-1",
|
||||
user_id="user-1",
|
||||
invoke_from=mocker.Mock(),
|
||||
)
|
||||
|
||||
handler._invoke_from = invoke_from
|
||||
|
||||
# Act
|
||||
handler.on_query("test query", "dataset-1")
|
||||
|
||||
# Assert
|
||||
mock_db.session.add.assert_called_once()
|
||||
dataset_query = mock_db.session.add.call_args.args[0]
|
||||
assert dataset_query.created_by_role == expected_role
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_query_none_values(self, mocker, mock_queue_manager):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
handler = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=mock_queue_manager,
|
||||
app_id=None,
|
||||
message_id=None,
|
||||
user_id=None,
|
||||
invoke_from=None,
|
||||
)
|
||||
|
||||
handler.on_query(None, None)
|
||||
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestOnToolEnd:
|
||||
def test_on_tool_end_no_metadata(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = None
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
def test_on_tool_end_dataset_document_not_found(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
def test_on_tool_end_parent_child_index_with_child(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
mock_dataset_doc = mocker.Mock()
|
||||
from core.callback_handler.index_tool_callback_handler import IndexStructureType
|
||||
|
||||
mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||
mock_dataset_doc.dataset_id = "dataset-1"
|
||||
mock_dataset_doc.id = "doc-1"
|
||||
|
||||
mock_child_chunk = mocker.Mock()
|
||||
mock_child_chunk.segment_id = "segment-1"
|
||||
|
||||
mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk]
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
mock_dataset_doc = mocker.Mock()
|
||||
mock_dataset_doc.doc_form = "OTHER"
|
||||
|
||||
mock_db.session.scalar.return_value = mock_dataset_doc
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = {
|
||||
"document_id": "doc-1",
|
||||
"doc_id": "node-1",
|
||||
"dataset_id": "dataset-1",
|
||||
}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_empty_documents(self, handler):
|
||||
handler.on_tool_end([])
|
||||
|
||||
|
||||
class TestReturnRetrieverResourceInfo:
|
||||
def test_publish_called(self, handler, mock_queue_manager, mocker):
|
||||
mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent")
|
||||
|
||||
resources = [mocker.Mock()]
|
||||
|
||||
handler.return_retriever_resource_info(resources)
|
||||
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
@@ -1,184 +0,0 @@
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import (
|
||||
DifyWorkflowCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
class DummyToolInvokeMessage:
|
||||
"""Lightweight dummy to simulate ToolInvokeMessage behavior."""
|
||||
|
||||
def __init__(self, json_value: str):
|
||||
self._json_value = json_value
|
||||
|
||||
def model_dump_json(self):
|
||||
return self._json_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
"""Fixture to create handler instance with deterministic color."""
|
||||
instance = DifyWorkflowCallbackHandler()
|
||||
instance.color = "blue"
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_print_text(mocker):
|
||||
"""Mock print_text to avoid real stdout printing."""
|
||||
return mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text")
|
||||
|
||||
|
||||
class TestDifyWorkflowCallbackHandler:
|
||||
def test_on_tool_execution_single_output_success(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "test_tool"
|
||||
tool_inputs = {"a": 1}
|
||||
message = DummyToolInvokeMessage('{"key": "value"}')
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs=tool_inputs,
|
||||
tool_outputs=[message],
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert results == [message]
|
||||
assert mock_print_text.call_count == 4
|
||||
mock_print_text.assert_has_calls(
|
||||
[
|
||||
call("\n[on_tool_execution]\n", color="blue"),
|
||||
call("Tool: test_tool\n", color="blue"),
|
||||
call(
|
||||
"Outputs: " + message.model_dump_json()[:1000] + "\n",
|
||||
color="blue",
|
||||
),
|
||||
call("\n"),
|
||||
]
|
||||
)
|
||||
|
||||
def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "multi_tool"
|
||||
outputs = [
|
||||
DummyToolInvokeMessage('{"id": 1}'),
|
||||
DummyToolInvokeMessage('{"id": 2}'),
|
||||
]
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=outputs,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert results == outputs
|
||||
assert mock_print_text.call_count == 4 * len(outputs)
|
||||
|
||||
def test_on_tool_execution_empty_iterable(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "empty_tool"
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[],
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert results == []
|
||||
mock_print_text.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invalid_outputs", "expected_exception"),
|
||||
[
|
||||
(None, TypeError),
|
||||
(123, TypeError),
|
||||
("not_iterable", AttributeError),
|
||||
],
|
||||
)
|
||||
def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception):
|
||||
# Arrange
|
||||
tool_name = "invalid_tool"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(expected_exception):
|
||||
list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=invalid_outputs,
|
||||
)
|
||||
)
|
||||
|
||||
def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "long_json_tool"
|
||||
long_json = "x" * 1500
|
||||
message = DummyToolInvokeMessage(long_json)
|
||||
|
||||
# Act
|
||||
list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[message],
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
expected_truncated = long_json[:1000]
|
||||
mock_print_text.assert_any_call(
|
||||
"Outputs: " + expected_truncated + "\n",
|
||||
color="blue",
|
||||
)
|
||||
|
||||
def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "exception_tool"
|
||||
bad_message = MagicMock()
|
||||
bad_message.model_dump_json.side_effect = ValueError("JSON error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[bad_message],
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure first two prints happened before failure
|
||||
assert mock_print_text.call_count >= 2
|
||||
|
||||
def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "optional_params_tool"
|
||||
message = DummyToolInvokeMessage('{"data": "ok"}')
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[message],
|
||||
message_id=None,
|
||||
timer=None,
|
||||
trace_manager=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert results == [message]
|
||||
assert mock_print_text.call_count == 4
|
||||
@@ -1,9 +0,0 @@
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
|
||||
|
||||
def test_planning_strategy_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert PlanningStrategy.ROUTER.value == "router"
|
||||
assert PlanningStrategy.REACT_ROUTER.value == "react_router"
|
||||
assert PlanningStrategy.REACT.value == "react"
|
||||
assert PlanningStrategy.FUNCTION_CALL.value == "function_call"
|
||||
@@ -1,18 +0,0 @@
|
||||
from core.entities.document_task import DocumentTask
|
||||
|
||||
|
||||
def test_document_task_keeps_indexing_identifiers() -> None:
|
||||
# Arrange
|
||||
document_ids = ("doc-1", "doc-2")
|
||||
|
||||
# Act
|
||||
task = DocumentTask(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="dataset-1",
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert task.tenant_id == "tenant-1"
|
||||
assert task.dataset_id == "dataset-1"
|
||||
assert task.document_ids == document_ids
|
||||
@@ -1,7 +0,0 @@
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
|
||||
|
||||
def test_embedding_input_type_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert EmbeddingInputType.DOCUMENT.value == "document"
|
||||
assert EmbeddingInputType.QUERY.value == "query"
|
||||
@@ -1,45 +0,0 @@
|
||||
from core.entities.execution_extra_content import (
|
||||
ExecutionExtraContentDomainModel,
|
||||
HumanInputContent,
|
||||
HumanInputFormDefinition,
|
||||
HumanInputFormSubmissionData,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.enums import FormInputType
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
def test_human_input_content_defaults_and_domain_alias() -> None:
|
||||
# Arrange
|
||||
form_definition = HumanInputFormDefinition(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_content="Please confirm",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")],
|
||||
actions=[UserAction(id="confirm", title="Confirm")],
|
||||
resolved_default_values={"answer": "yes"},
|
||||
expiration_time=1_700_000_000,
|
||||
)
|
||||
submission_data = HumanInputFormSubmissionData(
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="Please confirm",
|
||||
action_id="confirm",
|
||||
action_text="Confirm",
|
||||
)
|
||||
|
||||
# Act
|
||||
content = HumanInputContent(
|
||||
workflow_run_id="workflow-run-1",
|
||||
submitted=True,
|
||||
form_definition=form_definition,
|
||||
form_submission_data=submission_data,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert form_definition.model_config.get("frozen") is True
|
||||
assert content.type == ExecutionContentType.HUMAN_INPUT
|
||||
assert content.form_definition is form_definition
|
||||
assert content.form_submission_data is submission_data
|
||||
assert ExecutionExtraContentDomainModel is HumanInputContent
|
||||
@@ -1,45 +0,0 @@
|
||||
from core.entities.knowledge_entities import (
|
||||
PipelineDataset,
|
||||
PipelineDocument,
|
||||
PipelineGenerateResponse,
|
||||
)
|
||||
|
||||
|
||||
def test_pipeline_dataset_normalizes_none_description() -> None:
|
||||
# Arrange / Act
|
||||
dataset = PipelineDataset(
|
||||
id="dataset-1",
|
||||
name="Dataset",
|
||||
description=None,
|
||||
chunk_structure="parent-child",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert dataset.description == ""
|
||||
|
||||
|
||||
def test_pipeline_generate_response_builds_nested_models() -> None:
|
||||
# Arrange
|
||||
dataset = PipelineDataset(
|
||||
id="dataset-1",
|
||||
name="Dataset",
|
||||
description="Knowledge base",
|
||||
chunk_structure="parent-child",
|
||||
)
|
||||
document = PipelineDocument(
|
||||
id="doc-1",
|
||||
position=1,
|
||||
data_source_type="file",
|
||||
data_source_info={"name": "spec.pdf"},
|
||||
name="spec.pdf",
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document])
|
||||
|
||||
# Assert
|
||||
assert response.batch == "batch-1"
|
||||
assert response.dataset.id == "dataset-1"
|
||||
assert response.documents[0].id == "doc-1"
|
||||
@@ -1,450 +0,0 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities import mcp_provider as mcp_provider_module
|
||||
from core.entities.mcp_provider import (
|
||||
DEFAULT_EXPIRES_IN,
|
||||
DEFAULT_TOKEN_TYPE,
|
||||
MCPProviderEntity,
|
||||
)
|
||||
from core.mcp.types import OAuthTokens
|
||||
|
||||
|
||||
def _build_mcp_provider_entity() -> MCPProviderEntity:
|
||||
now = datetime(2025, 1, 1, tzinfo=UTC)
|
||||
return MCPProviderEntity(
|
||||
id="provider-1",
|
||||
provider_id="server-1",
|
||||
name="Example MCP",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
server_url="encrypted-server-url",
|
||||
headers={},
|
||||
timeout=30,
|
||||
sse_read_timeout=300,
|
||||
authed=False,
|
||||
credentials={},
|
||||
tools=[],
|
||||
icon={"en_US": "icon.png"},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def test_from_db_model_maps_fields() -> None:
|
||||
# Arrange
|
||||
now = datetime(2025, 1, 1, tzinfo=UTC)
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
server_identifier="server-1",
|
||||
name="Example MCP",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
server_url="encrypted-server-url",
|
||||
headers={"Authorization": "enc"},
|
||||
timeout=15,
|
||||
sse_read_timeout=120,
|
||||
authed=True,
|
||||
credentials={"access_token": "enc-token"},
|
||||
tool_dict=[{"name": "search"}],
|
||||
icon=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Act
|
||||
entity = MCPProviderEntity.from_db_model(db_provider)
|
||||
|
||||
# Assert
|
||||
assert entity.provider_id == "server-1"
|
||||
assert entity.tools == [{"name": "search"}]
|
||||
assert entity.icon == ""
|
||||
|
||||
|
||||
def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com")
|
||||
|
||||
# Act
|
||||
redirect_url = entity.redirect_url
|
||||
|
||||
# Assert
|
||||
assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback"
|
||||
|
||||
|
||||
def test_client_metadata_for_authorization_code_flow() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
|
||||
# Act
|
||||
metadata = entity.client_metadata
|
||||
|
||||
# Assert
|
||||
assert metadata.grant_types == ["refresh_token", "authorization_code"]
|
||||
assert metadata.redirect_uris == [entity.redirect_url]
|
||||
assert metadata.response_types == ["code"]
|
||||
|
||||
|
||||
def test_client_metadata_for_client_credentials_flow() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {"client_information": {"grant_types": ["client_credentials"]}}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
# Act
|
||||
metadata = entity.client_metadata
|
||||
|
||||
# Assert
|
||||
assert metadata.grant_types == ["refresh_token", "client_credentials"]
|
||||
assert metadata.redirect_uris == []
|
||||
assert metadata.response_types == []
|
||||
|
||||
|
||||
def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_information": {"grant_types": ["authorization_code"]},
|
||||
}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
# Act
|
||||
metadata = entity.client_metadata
|
||||
|
||||
# Assert
|
||||
assert metadata.grant_types == ["refresh_token", "authorization_code"]
|
||||
assert metadata.redirect_uris == [entity.redirect_url]
|
||||
assert metadata.response_types == ["code"]
|
||||
|
||||
|
||||
def test_provider_icon_returns_icon_dict_as_is() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
|
||||
|
||||
# Act
|
||||
icon = entity.provider_icon
|
||||
|
||||
# Assert
|
||||
assert icon == {"en_US": "icon.png"}
|
||||
|
||||
|
||||
def test_provider_icon_uses_signed_url_for_plain_path() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"})
|
||||
|
||||
with patch(
|
||||
"core.entities.mcp_provider.file_helpers.get_signed_file_url",
|
||||
return_value="https://signed.example.com/icons/mcp.png",
|
||||
) as mock_get_signed_url:
|
||||
# Act
|
||||
icon = entity.provider_icon
|
||||
|
||||
# Assert
|
||||
mock_get_signed_url.assert_called_once_with("icons/mcp.png")
|
||||
assert icon == "https://signed.example.com/icons/mcp.png"
|
||||
|
||||
|
||||
def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
|
||||
|
||||
with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
|
||||
# Act
|
||||
response = entity.to_api_response(include_sensitive=False)
|
||||
|
||||
# Assert
|
||||
assert response["author"] == "Anonymous"
|
||||
assert response["masked_headers"] == {}
|
||||
assert response["is_dynamic_registration"] is True
|
||||
assert "authentication" not in response
|
||||
|
||||
|
||||
def test_to_api_response_with_sensitive_data_includes_masked_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(
|
||||
update={
|
||||
"credentials": {"client_information": {"is_dynamic_registration": False}},
|
||||
"icon": {"en_US": "icon.png"},
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
|
||||
with patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}):
|
||||
with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}):
|
||||
# Act
|
||||
response = entity.to_api_response(user_name="Rajat", include_sensitive=True)
|
||||
|
||||
# Assert
|
||||
assert response["author"] == "Rajat"
|
||||
assert response["masked_headers"] == {"Authorization": "Be****"}
|
||||
assert response["authentication"] == {"client_id": "cl****"}
|
||||
assert response["is_dynamic_registration"] is False
|
||||
|
||||
|
||||
def test_retrieve_client_information_decrypts_nested_secret() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt:
|
||||
# Act
|
||||
client_info = entity.retrieve_client_information()
|
||||
|
||||
# Assert
|
||||
assert client_info is not None
|
||||
assert client_info.client_id == "client-1"
|
||||
assert client_info.client_secret == "plain-secret"
|
||||
mock_decrypt.assert_called_once_with("tenant-1", "enc-secret")
|
||||
|
||||
|
||||
def test_retrieve_client_information_returns_none_for_missing_data() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
|
||||
# Act
|
||||
result_empty = entity.retrieve_client_information()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
|
||||
# Act
|
||||
result_invalid = entity.retrieve_client_information()
|
||||
|
||||
# Assert
|
||||
assert result_empty is None
|
||||
assert result_invalid is None
|
||||
|
||||
|
||||
def test_masked_server_url_hides_path_segments() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"decrypt_server_url",
|
||||
return_value="https://api.example.com/v1/mcp?query=1",
|
||||
):
|
||||
# Act
|
||||
masked_url = entity.masked_server_url()
|
||||
|
||||
# Assert
|
||||
assert masked_url == "https://api.example.com/******?query=1"
|
||||
|
||||
|
||||
def test_mask_value_covers_short_and_long_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
# Act
|
||||
short_masked = entity._mask_value("short")
|
||||
long_masked = entity._mask_value("abcdefghijkl")
|
||||
|
||||
# Assert
|
||||
assert short_masked == "*****"
|
||||
assert long_masked == "ab********kl"
|
||||
|
||||
|
||||
def test_masked_headers_masks_all_decrypted_header_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}):
|
||||
# Act
|
||||
masked = entity.masked_headers()
|
||||
|
||||
# Assert
|
||||
assert masked == {"Authorization": "ab****gh"}
|
||||
|
||||
|
||||
def test_masked_credentials_handles_nested_secret_fields() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {
|
||||
"client_information": {
|
||||
"client_id": "client-id",
|
||||
"encrypted_client_secret": "encrypted-value",
|
||||
"client_secret": "plain-secret",
|
||||
}
|
||||
}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"):
|
||||
# Act
|
||||
masked = entity.masked_credentials()
|
||||
|
||||
# Assert
|
||||
assert masked["client_id"] == "cl*****id"
|
||||
assert masked["client_secret"] == "pl********et"
|
||||
|
||||
|
||||
def test_masked_credentials_returns_empty_for_missing_client_information() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
|
||||
# Act
|
||||
masked_empty = entity.masked_credentials()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
|
||||
# Act
|
||||
masked_invalid = entity.masked_credentials()
|
||||
|
||||
# Assert
|
||||
assert masked_empty == {}
|
||||
assert masked_invalid == {}
|
||||
|
||||
|
||||
def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
|
||||
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"decrypt_credentials",
|
||||
return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"},
|
||||
):
|
||||
# Act
|
||||
tokens = entity.retrieve_tokens()
|
||||
|
||||
# Assert
|
||||
assert isinstance(tokens, OAuthTokens)
|
||||
assert tokens.access_token == "token"
|
||||
assert tokens.token_type == DEFAULT_TOKEN_TYPE
|
||||
assert tokens.expires_in == DEFAULT_EXPIRES_IN
|
||||
assert tokens.refresh_token == "refresh"
|
||||
|
||||
|
||||
def test_retrieve_tokens_returns_none_when_access_token_missing() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt:
|
||||
# Act
|
||||
tokens = entity.retrieve_tokens()
|
||||
|
||||
# Assert
|
||||
mock_decrypt.assert_called_once()
|
||||
assert tokens is None
|
||||
|
||||
|
||||
def test_decrypt_server_url_delegates_to_encrypter() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock:
|
||||
# Act
|
||||
decrypted = entity.decrypt_server_url()
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once_with("tenant-1", "encrypted-server-url")
|
||||
assert decrypted == "https://api.example.com"
|
||||
|
||||
|
||||
def test_decrypt_authentication_injects_authorization_for_oauth() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}})
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}):
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"retrieve_tokens",
|
||||
return_value=OAuthTokens(access_token="abc123", token_type="bearer"),
|
||||
):
|
||||
# Act
|
||||
headers = entity.decrypt_authentication()
|
||||
|
||||
# Assert
|
||||
assert headers["Authorization"] == "Bearer abc123"
|
||||
|
||||
|
||||
def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(
|
||||
update={"authed": True, "headers": {"Authorization": "encrypted-header"}}
|
||||
)
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}):
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"retrieve_tokens",
|
||||
return_value=OAuthTokens(access_token="abc", token_type="bearer"),
|
||||
) as mock_tokens:
|
||||
# Act
|
||||
headers = entity.decrypt_authentication()
|
||||
|
||||
# Assert
|
||||
mock_tokens.assert_not_called()
|
||||
assert headers == {"Authorization": "existing"}
|
||||
|
||||
|
||||
def test_decrypt_dict_returns_empty_for_empty_input() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
# Act
|
||||
decrypted = entity._decrypt_dict({})
|
||||
|
||||
# Assert
|
||||
assert decrypted == {}
|
||||
|
||||
|
||||
def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""}
|
||||
|
||||
# Act
|
||||
result = entity._decrypt_dict(input_data)
|
||||
|
||||
# Assert
|
||||
assert result is input_data
|
||||
|
||||
|
||||
def test_decrypt_dict_only_decrypts_top_level_string_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
decryptor = Mock()
|
||||
decryptor.decrypt.return_value = {"api_key": "plain-key"}
|
||||
|
||||
def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache):
|
||||
assert tenant_id == "tenant-1"
|
||||
assert any(item.name == "api_key" for item in config)
|
||||
return decryptor, None
|
||||
|
||||
with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter):
|
||||
# Act
|
||||
result = entity._decrypt_dict(
|
||||
{
|
||||
"api_key": "encrypted-key",
|
||||
"nested": {"client_id": "unchanged"},
|
||||
"empty": "",
|
||||
"count": 2,
|
||||
}
|
||||
)
|
||||
|
||||
# Assert
|
||||
decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"})
|
||||
assert result["api_key"] == "plain-key"
|
||||
assert result["nested"] == {"client_id": "unchanged"}
|
||||
assert result["count"] == 2
|
||||
|
||||
|
||||
def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock:
|
||||
# Act
|
||||
headers = entity.decrypt_headers()
|
||||
credentials = entity.decrypt_credentials()
|
||||
|
||||
# Assert
|
||||
assert mock.call_count == 2
|
||||
assert headers == {"h": "v"}
|
||||
assert credentials == {"c": "v"}
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Unit tests for model entity behavior and invariants.
|
||||
|
||||
Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus,
|
||||
ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n
|
||||
labels are provided via I18nObject, model metadata aligns with FetchFrom and
|
||||
ModelType expectations, and ProviderEntity/ConfigurateMethod interactions
|
||||
drive provider mapping behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.model_entities import (
|
||||
DefaultModelEntity,
|
||||
DefaultModelProviderEntity,
|
||||
ModelStatus,
|
||||
ProviderModelWithStatusEntity,
|
||||
SimpleModelProviderEntity,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
|
||||
|
||||
def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity:
|
||||
return ProviderModelWithStatusEntity(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
|
||||
# Arrange
|
||||
provider_entity = ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
# Act
|
||||
simple_provider = SimpleModelProviderEntity(provider_entity)
|
||||
|
||||
# Assert
|
||||
assert simple_provider.provider == "openai"
|
||||
assert simple_provider.label.en_US == "OpenAI"
|
||||
assert simple_provider.supported_model_types == [ModelType.LLM]
|
||||
|
||||
|
||||
def test_provider_model_with_status_raises_for_known_error_statuses() -> None:
|
||||
# Arrange
|
||||
expectations = {
|
||||
ModelStatus.NO_CONFIGURE: "Model is not configured",
|
||||
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
|
||||
ModelStatus.NO_PERMISSION: "No permission to use this model",
|
||||
ModelStatus.DISABLED: "Model is disabled",
|
||||
}
|
||||
|
||||
for status, message in expectations.items():
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_build_model_with_status(status).raise_for_status()
|
||||
|
||||
|
||||
def test_provider_model_with_status_allows_active_and_credential_removed() -> None:
|
||||
# Arrange
|
||||
active_model = _build_model_with_status(ModelStatus.ACTIVE)
|
||||
removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED)
|
||||
|
||||
# Act / Assert
|
||||
active_model.raise_for_status()
|
||||
removed_model.raise_for_status()
|
||||
|
||||
|
||||
def test_default_model_entity_accepts_model_field_name() -> None:
|
||||
# Arrange / Act
|
||||
default_model = DefaultModelEntity(
|
||||
model="gpt-4o-mini",
|
||||
model_type=ModelType.LLM,
|
||||
provider=DefaultModelProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert default_model.model == "gpt-4o-mini"
|
||||
assert default_model.provider.provider == "openai"
|
||||
@@ -1,22 +0,0 @@
|
||||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
|
||||
|
||||
def test_common_parameter_type_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert CommonParameterType.SECRET_INPUT.value == "secret-input"
|
||||
assert CommonParameterType.MODEL_SELECTOR.value == "model-selector"
|
||||
assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select"
|
||||
assert CommonParameterType.ARRAY.value == "array"
|
||||
assert CommonParameterType.OBJECT.value == "object"
|
||||
|
||||
|
||||
def test_selector_scope_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert AppSelectorScope.WORKFLOW.value == "workflow"
|
||||
assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding"
|
||||
assert ToolSelectorScope.BUILTIN.value == "builtin"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,72 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope
|
||||
from core.entities.provider_entities import (
|
||||
BasicProviderConfig,
|
||||
ModelSettings,
|
||||
ProviderConfig,
|
||||
ProviderQuotaType,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
def test_provider_quota_type_value_of_returns_enum_member() -> None:
|
||||
# Arrange / Act
|
||||
quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value)
|
||||
|
||||
# Assert
|
||||
assert quota_type == ProviderQuotaType.TRIAL
|
||||
|
||||
|
||||
def test_provider_quota_type_value_of_rejects_unknown_values() -> None:
|
||||
# Arrange / Act / Assert
|
||||
with pytest.raises(ValueError, match="No matching enum found"):
|
||||
ProviderQuotaType.value_of("enterprise")
|
||||
|
||||
|
||||
def test_basic_provider_config_type_value_of_handles_known_values() -> None:
|
||||
# Arrange / Act
|
||||
parameter_type = BasicProviderConfig.Type.value_of("text-input")
|
||||
|
||||
# Assert
|
||||
assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT
|
||||
|
||||
|
||||
def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None:
|
||||
# Arrange / Act / Assert
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
BasicProviderConfig.Type.value_of("unknown")
|
||||
|
||||
|
||||
def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None:
|
||||
# Arrange
|
||||
provider_config = ProviderConfig(
|
||||
type=BasicProviderConfig.Type.SELECT,
|
||||
name="workspace",
|
||||
scope=AppSelectorScope.ALL,
|
||||
options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))],
|
||||
)
|
||||
|
||||
# Act
|
||||
basic_config = provider_config.to_basic_provider_config()
|
||||
|
||||
# Assert
|
||||
assert isinstance(basic_config, BasicProviderConfig)
|
||||
assert basic_config.type == BasicProviderConfig.Type.SELECT
|
||||
assert basic_config.name == "workspace"
|
||||
|
||||
|
||||
def test_model_settings_accepts_model_field_name() -> None:
|
||||
# Arrange / Act
|
||||
settings = ModelSettings(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
load_balancing_configs=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert settings.model == "gpt-4o"
|
||||
assert settings.model_type == ModelType.LLM
|
||||
@@ -1,137 +0,0 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
from models.api_based_extension import APIBasedExtensionPoint
|
||||
|
||||
|
||||
def test_request_success(mocker):
|
||||
# Mock httpx.Client and its context manager
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
mocker.patch("httpx.Client", return_value=mock_client)
|
||||
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
result = requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
method="POST",
|
||||
url="http://example.com",
|
||||
json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": "bar"}},
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"},
|
||||
)
|
||||
|
||||
|
||||
def test_request_with_ssrf_proxy(mocker):
|
||||
# Mock dify_config
|
||||
mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
|
||||
mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081")
|
||||
|
||||
# Mock httpx.Client
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
|
||||
# Mock response
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
|
||||
# Mock HTTPTransport
|
||||
mock_transport = mocker.patch("httpx.HTTPTransport")
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||
|
||||
# Verify httpx.Client was called with mounts
|
||||
mock_client_class.assert_called_once()
|
||||
kwargs = mock_client_class.call_args.kwargs
|
||||
assert "mounts" in kwargs
|
||||
assert "http://" in kwargs["mounts"]
|
||||
assert "https://" in kwargs["mounts"]
|
||||
assert mock_transport.call_count == 2
|
||||
|
||||
|
||||
def test_request_with_only_one_proxy_config(mocker):
|
||||
# Mock dify_config with only one proxy
|
||||
mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
|
||||
mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None)
|
||||
|
||||
# Mock httpx.Client
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
|
||||
# Mock response
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||
|
||||
# Verify httpx.Client was called with mounts=None (default)
|
||||
mock_client_class.assert_called_once()
|
||||
kwargs = mock_client_class.call_args.kwargs
|
||||
assert kwargs.get("mounts") is None
|
||||
|
||||
|
||||
def test_request_timeout(mocker):
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
mocker.patch("httpx.Client", return_value=mock_client)
|
||||
mock_client_instance.request.side_effect = httpx.TimeoutException("timeout")
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
with pytest.raises(ValueError, match="request timeout"):
|
||||
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||
|
||||
|
||||
def test_request_connection_error(mocker):
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
mocker.patch("httpx.Client", return_value=mock_client)
|
||||
mock_client_instance.request.side_effect = httpx.RequestError("error")
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
with pytest.raises(ValueError, match="request connection error"):
|
||||
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||
|
||||
|
||||
def test_request_error_status_code(mocker):
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
mocker.patch("httpx.Client", return_value=mock_client)
|
||||
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = "Not Found"
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"):
|
||||
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||
|
||||
|
||||
def test_request_error_status_code_long_content(mocker):
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_client_instance = mock_client.__enter__.return_value
|
||||
mocker.patch("httpx.Client", return_value=mock_client)
|
||||
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "A" * 200 # Testing truncation of content
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
|
||||
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||
expected_content = "A" * 100
|
||||
with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"):
|
||||
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||
@@ -1,281 +0,0 @@
|
||||
import json
|
||||
import types
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.extension.extensible import Extensible
|
||||
|
||||
|
||||
class TestExtensible:
|
||||
def test_init(self):
|
||||
tenant_id = "tenant_123"
|
||||
config = {"key": "value"}
|
||||
ext = Extensible(tenant_id, config)
|
||||
assert ext.tenant_id == tenant_id
|
||||
assert ext.config == config
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
@patch("core.extension.extensible.os.path.exists")
|
||||
@patch("core.extension.extensible.Path.read_text")
|
||||
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||
@patch("core.extension.extensible.sort_to_dict_by_position_map")
|
||||
def test_scan_extensions_success(
|
||||
self,
|
||||
mock_sort,
|
||||
mock_module_from_spec,
|
||||
mock_read_text,
|
||||
mock_exists,
|
||||
mock_isdir,
|
||||
mock_listdir,
|
||||
mock_dirname,
|
||||
mock_find_spec,
|
||||
):
|
||||
# Setup
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
|
||||
module_spec = MagicMock()
|
||||
module_spec.loader = MagicMock()
|
||||
|
||||
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
|
||||
mock_listdir.side_effect = [
|
||||
["ext1"], # package_dir
|
||||
["ext1.py", "__builtin__"], # subdir_path
|
||||
]
|
||||
mock_isdir.return_value = True
|
||||
|
||||
mock_exists.return_value = True
|
||||
mock_read_text.return_value = "10"
|
||||
|
||||
# Use types.ModuleType to avoid MagicMock __dict__ issues
|
||||
mock_mod = types.ModuleType("ext1")
|
||||
|
||||
class MockExtension(Extensible):
|
||||
pass
|
||||
|
||||
mock_mod.MockExtension = MockExtension
|
||||
mock_module_from_spec.return_value = mock_mod
|
||||
|
||||
mock_sort.side_effect = lambda position_map, data, name_func: data
|
||||
|
||||
# Execute
|
||||
results = Extensible.scan_extensions()
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert results[0].name == "ext1"
|
||||
assert results[0].position == 10
|
||||
assert results[0].builtin is True
|
||||
assert results[0].extension_class == MockExtension
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
def test_scan_extensions_package_not_found(self, mock_find_spec):
|
||||
mock_find_spec.return_value = None
|
||||
with pytest.raises(ImportError, match="Could not find package"):
|
||||
Extensible.scan_extensions()
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
mock_find_spec.return_value = package_spec
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
|
||||
mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []]
|
||||
|
||||
mock_isdir.side_effect = [False, True]
|
||||
|
||||
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
|
||||
results = Extensible.scan_extensions()
|
||||
assert len(results) == 0
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
@patch("core.extension.extensible.os.path.exists")
|
||||
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||
def test_scan_extensions_not_builtin_success(
|
||||
self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||
):
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
|
||||
module_spec = MagicMock()
|
||||
module_spec.loader = MagicMock()
|
||||
|
||||
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
|
||||
mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]]
|
||||
mock_isdir.return_value = True
|
||||
|
||||
# exists checks: only schema.json needs to exist
|
||||
mock_exists.return_value = True
|
||||
|
||||
mock_mod = types.ModuleType("ext1")
|
||||
|
||||
class MockExtension(Extensible):
|
||||
pass
|
||||
|
||||
mock_mod.MockExtension = MockExtension
|
||||
mock_module_from_spec.return_value = mock_mod
|
||||
|
||||
schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]})
|
||||
|
||||
with (
|
||||
patch("builtins.open", mock_open(read_data=schema_content)),
|
||||
patch(
|
||||
"core.extension.extensible.sort_to_dict_by_position_map",
|
||||
side_effect=lambda position_map, data, name_func: data,
|
||||
),
|
||||
):
|
||||
results = Extensible.scan_extensions()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].name == "ext1"
|
||||
assert results[0].builtin is False
|
||||
assert results[0].label == {"en": "Test"}
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
@patch("core.extension.extensible.os.path.exists")
|
||||
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||
def test_scan_extensions_not_builtin_missing_schema(
|
||||
self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||
):
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
|
||||
module_spec = MagicMock()
|
||||
module_spec.loader = MagicMock()
|
||||
|
||||
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
|
||||
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
|
||||
mock_isdir.return_value = True
|
||||
|
||||
# exists: only schema.json checked, and return False
|
||||
mock_exists.return_value = False
|
||||
|
||||
mock_mod = types.ModuleType("ext1")
|
||||
|
||||
class MockExtension(Extensible):
|
||||
pass
|
||||
|
||||
mock_mod.MockExtension = MockExtension
|
||||
mock_module_from_spec.return_value = mock_mod
|
||||
|
||||
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
|
||||
results = Extensible.scan_extensions()
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||
@patch("core.extension.extensible.os.path.exists")
|
||||
def test_scan_extensions_no_extension_class(
|
||||
self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||
):
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
module_spec = MagicMock()
|
||||
module_spec.loader = MagicMock()
|
||||
|
||||
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
|
||||
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
|
||||
mock_isdir.return_value = True
|
||||
|
||||
# Mock not builtin
|
||||
mock_exists.return_value = False
|
||||
|
||||
mock_mod = types.ModuleType("ext1")
|
||||
mock_mod.SomeOtherClass = type("SomeOtherClass", (), {})
|
||||
mock_module_from_spec.return_value = mock_mod
|
||||
|
||||
# We need to ensure we don't crash if checking schema (but we won't reach there because class not found)
|
||||
|
||||
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
|
||||
results = Extensible.scan_extensions()
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
|
||||
mock_find_spec.side_effect = [package_spec, None] # No module spec
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
|
||||
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
|
||||
mock_isdir.return_value = True
|
||||
|
||||
with pytest.raises(ImportError, match="Failed to load module"):
|
||||
Extensible.scan_extensions()
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
def test_scan_extensions_general_exception(self, mock_find_spec):
|
||||
mock_find_spec.side_effect = Exception("Unexpected error")
|
||||
with pytest.raises(Exception, match="Unexpected error"):
|
||||
Extensible.scan_extensions()
|
||||
|
||||
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||
@patch("core.extension.extensible.os.path.dirname")
|
||||
@patch("core.extension.extensible.os.listdir")
|
||||
@patch("core.extension.extensible.os.path.isdir")
|
||||
@patch("core.extension.extensible.os.path.exists")
|
||||
@patch("core.extension.extensible.Path.read_text")
|
||||
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||
def test_scan_extensions_builtin_without_position_file(
|
||||
self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||
):
|
||||
package_spec = MagicMock()
|
||||
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||
module_spec = MagicMock()
|
||||
module_spec.loader = MagicMock()
|
||||
|
||||
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||
mock_dirname.return_value = "/path/to/pkg"
|
||||
mock_listdir.side_effect = [["ext1"], ["ext1.py", "__builtin__"]]
|
||||
mock_isdir.return_value = True
|
||||
|
||||
# builtin exists in listdir, but os.path.exists(builtin_file_path) returns False
|
||||
mock_exists.return_value = False
|
||||
|
||||
mock_mod = types.ModuleType("ext1")
|
||||
|
||||
class MockExtension(Extensible):
|
||||
pass
|
||||
|
||||
mock_mod.MockExtension = MockExtension
|
||||
mock_module_from_spec.return_value = mock_mod
|
||||
|
||||
with patch(
|
||||
"core.extension.extensible.sort_to_dict_by_position_map",
|
||||
side_effect=lambda position_map, data, name_func: data,
|
||||
):
|
||||
results = Extensible.scan_extensions()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].position == 0
|
||||
@@ -1,90 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.extension.extensible import ExtensionModule, ModuleExtension
|
||||
from core.extension.extension import Extension
|
||||
|
||||
|
||||
class TestExtension:
|
||||
def setup_method(self):
|
||||
# Reset the private class attribute before each test
|
||||
Extension._Extension__module_extensions = {}
|
||||
|
||||
def test_init(self):
|
||||
# Mock scan_extensions for Moderation and ExternalDataTool
|
||||
mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")}
|
||||
mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")}
|
||||
|
||||
extension = Extension()
|
||||
|
||||
# We need to mock scan_extensions on the classes defined in Extension.module_classes
|
||||
with (
|
||||
patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions),
|
||||
patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions),
|
||||
):
|
||||
extension.init()
|
||||
|
||||
# Check if internal state is updated
|
||||
internal_state = Extension._Extension__module_extensions
|
||||
assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions
|
||||
assert internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions
|
||||
|
||||
def test_module_extensions_success(self):
|
||||
# Setup data
|
||||
mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")}
|
||||
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions}
|
||||
|
||||
extension = Extension()
|
||||
result = extension.module_extensions(ExtensionModule.MODERATION.value)
|
||||
|
||||
assert len(result) == 2
|
||||
assert any(e.name == "name1" for e in result)
|
||||
assert any(e.name == "name2" for e in result)
|
||||
|
||||
def test_module_extensions_not_found(self):
|
||||
extension = Extension()
|
||||
with pytest.raises(ValueError, match="Extension Module unknown not found"):
|
||||
extension.module_extensions("unknown")
|
||||
|
||||
def test_module_extension_success(self):
|
||||
mock_ext = ModuleExtension(name="test_ext")
|
||||
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
|
||||
|
||||
extension = Extension()
|
||||
result = extension.module_extension(ExtensionModule.MODERATION, "test_ext")
|
||||
assert result == mock_ext
|
||||
|
||||
def test_module_extension_module_not_found(self):
|
||||
extension = Extension()
|
||||
# ExtensionModule.MODERATION is "moderation"
|
||||
with pytest.raises(ValueError, match="Extension Module moderation not found"):
|
||||
extension.module_extension(ExtensionModule.MODERATION, "any")
|
||||
|
||||
def test_module_extension_extension_not_found(self):
|
||||
# We need a non-empty dict because 'if not module_extensions' in extension.py
|
||||
# returns True for an empty dict, which raises the module not found error instead.
|
||||
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}}
|
||||
|
||||
extension = Extension()
|
||||
with pytest.raises(ValueError, match="Extension unknown not found"):
|
||||
extension.module_extension(ExtensionModule.MODERATION, "unknown")
|
||||
|
||||
def test_extension_class_success(self):
|
||||
class MockClass:
|
||||
pass
|
||||
|
||||
mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass)
|
||||
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
|
||||
|
||||
extension = Extension()
|
||||
result = extension.extension_class(ExtensionModule.MODERATION, "test_ext")
|
||||
assert result == MockClass
|
||||
|
||||
def test_extension_class_none(self):
|
||||
mock_ext = ModuleExtension(name="test_ext", extension_class=None)
|
||||
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
|
||||
|
||||
extension = Extension()
|
||||
with pytest.raises(AssertionError):
|
||||
extension.extension_class(ExtensionModule.MODERATION, "test_ext")
|
||||
@@ -1,145 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.external_data_tool.api.api import ApiExternalDataTool
|
||||
from models.api_based_extension import APIBasedExtensionPoint
|
||||
|
||||
|
||||
def test_api_external_data_tool_name():
|
||||
assert ApiExternalDataTool.name == "api"
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
def test_validate_config_success(mock_db):
|
||||
mock_extension = MagicMock()
|
||||
mock_extension.id = "ext_id"
|
||||
mock_extension.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = mock_extension
|
||||
|
||||
# Should not raise exception
|
||||
ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
|
||||
|
||||
|
||||
def test_validate_config_missing_id():
|
||||
with pytest.raises(ValueError, match="api_based_extension_id is required"):
|
||||
ApiExternalDataTool.validate_config("tenant_id", {})
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
def test_validate_config_invalid_id(mock_db):
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="api_based_extension_id is invalid"):
|
||||
ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_tool():
|
||||
# Use standard kwargs as it inherits from ExternalDataTool which is typically a Pydantic BaseModel
|
||||
return ApiExternalDataTool(
|
||||
tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"}
|
||||
)
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
@patch("core.external_data_tool.api.api.encrypter")
|
||||
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||
def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||
mock_extension = MagicMock()
|
||||
mock_extension.id = "ext_id"
|
||||
mock_extension.tenant_id = "tenant_id"
|
||||
mock_extension.api_endpoint = "http://api"
|
||||
mock_extension.api_key = "encrypted_key"
|
||||
mock_db.session.scalar.return_value = mock_extension
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||
|
||||
mock_requestor = mock_requestor_class.return_value
|
||||
mock_requestor.request.return_value = {"result": "success_result"}
|
||||
|
||||
res = api_tool.query({"input1": "value1"}, "query_str")
|
||||
|
||||
assert res == "success_result"
|
||||
|
||||
mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key")
|
||||
mock_requestor.request.assert_called_once_with(
|
||||
point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
|
||||
params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"},
|
||||
)
|
||||
|
||||
|
||||
def test_query_missing_config():
|
||||
api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1")
|
||||
api_tool.config = None # Force None
|
||||
with pytest.raises(ValueError, match="config is required"):
|
||||
api_tool.query({}, "")
|
||||
|
||||
|
||||
def test_query_missing_extension_id():
|
||||
api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config={"dummy": "value"})
|
||||
with pytest.raises(AssertionError, match="api_based_extension_id is required"):
|
||||
api_tool.query({}, "")
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
def test_query_invalid_extension(mock_db, api_tool):
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"):
|
||||
api_tool.query({}, "")
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
@patch("core.external_data_tool.api.api.encrypter")
|
||||
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||
def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||
mock_extension = MagicMock()
|
||||
mock_extension.id = "ext_id"
|
||||
mock_extension.tenant_id = "tenant_id"
|
||||
mock_extension.api_endpoint = "http://api"
|
||||
mock_extension.api_key = "encrypted_key"
|
||||
mock_db.session.scalar.return_value = mock_extension
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||
|
||||
mock_requestor_class.side_effect = Exception("init error")
|
||||
|
||||
with pytest.raises(ValueError, match=".*error: init error"):
|
||||
api_tool.query({}, "")
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
@patch("core.external_data_tool.api.api.encrypter")
|
||||
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||
def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||
mock_extension = MagicMock()
|
||||
mock_extension.id = "ext_id"
|
||||
mock_extension.tenant_id = "tenant_id"
|
||||
mock_extension.api_endpoint = "http://api"
|
||||
mock_extension.api_key = "encrypted_key"
|
||||
mock_db.session.scalar.return_value = mock_extension
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||
|
||||
mock_requestor = mock_requestor_class.return_value
|
||||
mock_requestor.request.return_value = {"other": "value"}
|
||||
|
||||
with pytest.raises(ValueError, match=".*error: result not found in response"):
|
||||
api_tool.query({}, "")
|
||||
|
||||
|
||||
@patch("core.external_data_tool.api.api.db")
|
||||
@patch("core.external_data_tool.api.api.encrypter")
|
||||
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||
def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||
mock_extension = MagicMock()
|
||||
mock_extension.id = "ext_id"
|
||||
mock_extension.tenant_id = "tenant_id"
|
||||
mock_extension.api_endpoint = "http://api"
|
||||
mock_extension.api_key = "encrypted_key"
|
||||
mock_db.session.scalar.return_value = mock_extension
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||
|
||||
mock_requestor = mock_requestor_class.return_value
|
||||
mock_requestor.request.return_value = {"result": 123} # Not a string
|
||||
|
||||
with pytest.raises(ValueError, match=".*error: result is not string"):
|
||||
api_tool.query({}, "")
|
||||
@@ -1,66 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from core.external_data_tool.base import ExternalDataTool
|
||||
|
||||
|
||||
class TestExternalDataTool:
|
||||
def test_module_attribute(self):
|
||||
assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL
|
||||
|
||||
def test_init(self):
|
||||
# Create a concrete subclass to test init
|
||||
class ConcreteTool(ExternalDataTool):
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
return super().validate_config(tenant_id, config)
|
||||
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
return super().query(inputs, query)
|
||||
|
||||
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"})
|
||||
assert tool.tenant_id == "tenant_1"
|
||||
assert tool.app_id == "app_1"
|
||||
assert tool.variable == "var_1"
|
||||
assert tool.config == {"key": "value"}
|
||||
|
||||
def test_init_without_config(self):
|
||||
# Create a concrete subclass to test init
|
||||
class ConcreteTool(ExternalDataTool):
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
pass
|
||||
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
return ""
|
||||
|
||||
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
|
||||
assert tool.tenant_id == "tenant_1"
|
||||
assert tool.app_id == "app_1"
|
||||
assert tool.variable == "var_1"
|
||||
assert tool.config is None
|
||||
|
||||
def test_validate_config_raises_not_implemented(self):
|
||||
class ConcreteTool(ExternalDataTool):
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
return super().validate_config(tenant_id, config)
|
||||
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
return ""
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
ConcreteTool.validate_config("tenant_1", {})
|
||||
|
||||
def test_query_raises_not_implemented(self):
|
||||
class ConcreteTool(ExternalDataTool):
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
pass
|
||||
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
return super().query(inputs, query)
|
||||
|
||||
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
|
||||
with pytest.raises(NotImplementedError):
|
||||
tool.query({})
|
||||
@@ -1,115 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
|
||||
|
||||
class TestExternalDataFetch:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
def test_fetch_success(self, app):
|
||||
with app.app_context():
|
||||
fetcher = ExternalDataFetch()
|
||||
|
||||
# Setup mocks
|
||||
tool1 = ExternalDataVariableEntity(variable="var1", type="type1", config={"c1": "v1"})
|
||||
tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"})
|
||||
|
||||
external_data_tools = [tool1, tool2]
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
|
||||
with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
|
||||
# Create distinct mock instances for each tool to ensure deterministic results
|
||||
# This approach is robust regardless of thread scheduling order
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
def factory_side_effect(*args, **kwargs):
|
||||
variable = kwargs.get("variable")
|
||||
mock_instance = MagicMock()
|
||||
if variable == "var1":
|
||||
mock_instance.query.return_value = "result1"
|
||||
elif variable == "var2":
|
||||
mock_instance.query.return_value = "result2"
|
||||
return mock_instance
|
||||
|
||||
MockFactory.side_effect = factory_side_effect
|
||||
|
||||
result_inputs = fetcher.fetch(
|
||||
tenant_id="tenant1",
|
||||
app_id="app1",
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
)
|
||||
|
||||
# Each tool gets its deterministic result regardless of thread completion order
|
||||
assert result_inputs["var1"] == "result1"
|
||||
assert result_inputs["var2"] == "result2"
|
||||
assert result_inputs["input_key"] == "input_value"
|
||||
assert len(result_inputs) == 3
|
||||
|
||||
# Verify factory calls
|
||||
assert MockFactory.call_count == 2
|
||||
MockFactory.assert_any_call(
|
||||
name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"}
|
||||
)
|
||||
MockFactory.assert_any_call(
|
||||
name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"}
|
||||
)
|
||||
|
||||
def test_fetch_no_tools(self):
|
||||
# We don't necessarily need app_context if there are no tools,
|
||||
# but fetch calls current_app._get_current_object() only inside the loop.
|
||||
# Wait, let's look at the code.
|
||||
# for tool in external_data_tools:
|
||||
# executor.submit(..., current_app._get_current_object(), ...)
|
||||
# So if external_data_tools is empty, it shouldn't access current_app.
|
||||
fetcher = ExternalDataFetch()
|
||||
inputs = {"input_key": "input_value"}
|
||||
result_inputs = fetcher.fetch(
|
||||
tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query"
|
||||
)
|
||||
assert result_inputs == inputs
|
||||
assert result_inputs is not inputs # Should be a copy
|
||||
|
||||
def test_fetch_with_none_variable(self, app):
|
||||
with app.app_context():
|
||||
fetcher = ExternalDataFetch()
|
||||
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})
|
||||
|
||||
# Patch _query_external_data_tool to return None variable
|
||||
with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query:
|
||||
mock_query.return_value = (None, "some_result")
|
||||
|
||||
result_inputs = fetcher.fetch(
|
||||
tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q"
|
||||
)
|
||||
|
||||
assert "var1" not in result_inputs
|
||||
assert result_inputs == {"in": "val"}
|
||||
|
||||
def test_query_external_data_tool(self, app):
|
||||
fetcher = ExternalDataFetch()
|
||||
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})
|
||||
|
||||
with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
|
||||
mock_factory_instance = MockFactory.return_value
|
||||
mock_factory_instance.query.return_value = "query_result"
|
||||
|
||||
var, res = fetcher._query_external_data_tool(
|
||||
flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q"
|
||||
)
|
||||
|
||||
assert var == "var1"
|
||||
assert res == "query_result"
|
||||
MockFactory.assert_called_once_with(
|
||||
name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"}
|
||||
)
|
||||
mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q")
|
||||
@@ -1,58 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
|
||||
def test_external_data_tool_factory_init():
|
||||
with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
|
||||
mock_extension_class = MagicMock()
|
||||
mock_code_based_extension.extension_class.return_value = mock_extension_class
|
||||
|
||||
name = "test_tool"
|
||||
tenant_id = "tenant_123"
|
||||
app_id = "app_456"
|
||||
variable = "var_v"
|
||||
config = {"key": "value"}
|
||||
|
||||
factory = ExternalDataToolFactory(name, tenant_id, app_id, variable, config)
|
||||
|
||||
mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
mock_extension_class.assert_called_once_with(
|
||||
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
||||
)
|
||||
|
||||
|
||||
def test_external_data_tool_factory_validate_config():
|
||||
with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
|
||||
mock_extension_class = MagicMock()
|
||||
mock_code_based_extension.extension_class.return_value = mock_extension_class
|
||||
|
||||
name = "test_tool"
|
||||
tenant_id = "tenant_123"
|
||||
config = {"key": "value"}
|
||||
|
||||
ExternalDataToolFactory.validate_config(name, tenant_id, config)
|
||||
|
||||
mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
mock_extension_class.validate_config.assert_called_once_with(tenant_id, config)
|
||||
|
||||
|
||||
def test_external_data_tool_factory_query():
|
||||
with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
|
||||
mock_extension_class = MagicMock()
|
||||
mock_extension_instance = MagicMock()
|
||||
mock_extension_class.return_value = mock_extension_instance
|
||||
mock_code_based_extension.extension_class.return_value = mock_extension_class
|
||||
|
||||
mock_extension_instance.query.return_value = "query_result"
|
||||
|
||||
factory = ExternalDataToolFactory("name", "tenant", "app", "var", {})
|
||||
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "search_query"
|
||||
|
||||
result = factory.query(inputs, query)
|
||||
|
||||
assert result == "query_result"
|
||||
mock_extension_instance.query.assert_called_once_with(inputs, query)
|
||||
@@ -1,103 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class TestRuleConfigGeneratorOutputParser:
|
||||
def test_get_format_instructions(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
instructions = parser.get_format_instructions()
|
||||
assert instructions == (
|
||||
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
|
||||
)
|
||||
|
||||
def test_parse_success(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
text = """
|
||||
```json
|
||||
{
|
||||
"prompt": "This is a prompt",
|
||||
"variables": ["var1", "var2"],
|
||||
"opening_statement": "Hello!"
|
||||
}
|
||||
```
|
||||
"""
|
||||
result = parser.parse(text)
|
||||
assert result["prompt"] == "This is a prompt"
|
||||
assert result["variables"] == ["var1", "var2"]
|
||||
assert result["opening_statement"] == "Hello!"
|
||||
|
||||
def test_parse_invalid_json(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
text = "invalid json"
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
parser.parse(text)
|
||||
assert "Parsing text" in str(excinfo.value)
|
||||
assert "could not find json block in the output" in str(excinfo.value)
|
||||
|
||||
def test_parse_missing_keys(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
text = """
|
||||
```json
|
||||
{
|
||||
"prompt": "This is a prompt",
|
||||
"variables": ["var1", "var2"]
|
||||
}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
parser.parse(text)
|
||||
assert "expected key `opening_statement` to be present" in str(excinfo.value)
|
||||
|
||||
def test_parse_wrong_type_prompt(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
text = """
|
||||
```json
|
||||
{
|
||||
"prompt": 123,
|
||||
"variables": ["var1", "var2"],
|
||||
"opening_statement": "Hello!"
|
||||
}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
parser.parse(text)
|
||||
assert "Expected 'prompt' to be a string" in str(excinfo.value)
|
||||
|
||||
def test_parse_wrong_type_variables(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
text = """
|
||||
```json
|
||||
{
|
||||
"prompt": "This is a prompt",
|
||||
"variables": "not a list",
|
||||
"opening_statement": "Hello!"
|
||||
}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
parser.parse(text)
|
||||
assert "Expected 'variables' to be a list" in str(excinfo.value)
|
||||
|
||||
def test_parse_wrong_type_opening_statement(self):
|
||||
parser = RuleConfigGeneratorOutputParser()
|
||||
text = """
|
||||
```json
|
||||
{
|
||||
"prompt": "This is a prompt",
|
||||
"variables": ["var1", "var2"],
|
||||
"opening_statement": 123
|
||||
}
|
||||
```
|
||||
"""
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
parser.parse(text)
|
||||
assert "Expected 'opening_statement' to be a str" in str(excinfo.value)
|
||||
@@ -1,402 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
ResponseFormat,
|
||||
_handle_native_json_schema,
|
||||
_handle_prompt_based_schema,
|
||||
_parse_structured_output,
|
||||
_prepare_schema_for_model,
|
||||
_set_response_format,
|
||||
convert_boolean_to_string,
|
||||
invoke_llm_with_structured_output,
|
||||
remove_additional_properties,
|
||||
)
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType
|
||||
|
||||
|
||||
class TestStructuredOutput:
|
||||
def test_remove_additional_properties(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
"additionalProperties": False,
|
||||
"nested": {"type": "object", "additionalProperties": True},
|
||||
"items": [{"type": "object", "additionalProperties": False}],
|
||||
}
|
||||
remove_additional_properties(schema)
|
||||
assert "additionalProperties" not in schema
|
||||
assert "additionalProperties" not in schema["nested"]
|
||||
assert "additionalProperties" not in schema["items"][0]
|
||||
|
||||
# Test with non-dict input
|
||||
remove_additional_properties(None) # Should not raise
|
||||
remove_additional_properties([]) # Should not raise
|
||||
|
||||
def test_convert_boolean_to_string(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"is_active": {"type": "boolean"},
|
||||
"tags": {"type": "array", "items": {"type": "boolean"}},
|
||||
"list_schema": [{"type": "boolean"}],
|
||||
},
|
||||
}
|
||||
convert_boolean_to_string(schema)
|
||||
assert schema["properties"]["is_active"]["type"] == "string"
|
||||
assert schema["properties"]["tags"]["items"]["type"] == "string"
|
||||
assert schema["properties"]["list_schema"][0]["type"] == "string"
|
||||
|
||||
# Test with non-dict input
|
||||
convert_boolean_to_string(None) # Should not raise
|
||||
convert_boolean_to_string([]) # Should not raise
|
||||
|
||||
def test_parse_structured_output_valid(self):
|
||||
text = '{"key": "value"}'
|
||||
assert _parse_structured_output(text) == {"key": "value"}
|
||||
|
||||
def test_parse_structured_output_non_dict_valid_json(self):
|
||||
# Even if it's valid JSON, if it's not a dict, it should try repair or fail
|
||||
text = '["a", "b"]'
|
||||
with patch("json_repair.loads") as mock_repair:
|
||||
mock_repair.return_value = {"key": "value"}
|
||||
assert _parse_structured_output(text) == {"key": "value"}
|
||||
|
||||
def test_parse_structured_output_not_dict_fail_via_validate(self):
|
||||
# Force TypeAdapter to return a non-dict to trigger line 292
|
||||
with patch("pydantic.TypeAdapter.validate_json") as mock_validate:
|
||||
mock_validate.return_value = ["a list"]
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
_parse_structured_output('["a list"]')
|
||||
assert "Failed to parse structured output" in str(excinfo.value)
|
||||
|
||||
def test_parse_structured_output_repair_success(self):
|
||||
text = "{'key': 'value'}" # Invalid JSON (single quotes)
|
||||
# json_repair should handle this
|
||||
assert _parse_structured_output(text) == {"key": "value"}
|
||||
|
||||
def test_parse_structured_output_repair_list(self):
|
||||
# Deepseek-r1 case: result is a list containing a dict
|
||||
text = '[{"key": "value"}]'
|
||||
assert _parse_structured_output(text) == {"key": "value"}
|
||||
|
||||
def test_parse_structured_output_repair_list_no_dict(self):
|
||||
# Deepseek-r1 case: result is a list with NO dict
|
||||
text = "[1, 2, 3]"
|
||||
assert _parse_structured_output(text) == {}
|
||||
|
||||
def test_parse_structured_output_repair_fail(self):
|
||||
text = "not a json at all"
|
||||
with patch("json_repair.loads") as mock_repair:
|
||||
mock_repair.return_value = "still not a dict or list"
|
||||
with pytest.raises(OutputParserError):
|
||||
_parse_structured_output(text)
|
||||
|
||||
def test_set_response_format(self):
|
||||
# Test JSON
|
||||
params = {}
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label={"en_US": ""},
|
||||
type=ParameterType.STRING,
|
||||
help={"en_US": ""},
|
||||
options=[ResponseFormat.JSON],
|
||||
)
|
||||
]
|
||||
_set_response_format(params, rules)
|
||||
assert params["response_format"] == ResponseFormat.JSON
|
||||
|
||||
# Test JSON_OBJECT
|
||||
params = {}
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label={"en_US": ""},
|
||||
type=ParameterType.STRING,
|
||||
help={"en_US": ""},
|
||||
options=[ResponseFormat.JSON_OBJECT],
|
||||
)
|
||||
]
|
||||
_set_response_format(params, rules)
|
||||
assert params["response_format"] == ResponseFormat.JSON_OBJECT
|
||||
|
||||
def test_handle_native_json_schema(self):
|
||||
provider = "openai"
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.model = "gpt-4"
|
||||
structured_output_schema = {"type": "object"}
|
||||
model_parameters = {}
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label={"en_US": ""},
|
||||
type=ParameterType.STRING,
|
||||
help={"en_US": ""},
|
||||
options=[ResponseFormat.JSON_SCHEMA],
|
||||
)
|
||||
]
|
||||
|
||||
updated_params = _handle_native_json_schema(
|
||||
provider, model_schema, structured_output_schema, model_parameters, rules
|
||||
)
|
||||
|
||||
assert "json_schema" in updated_params
|
||||
assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"}
|
||||
assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA
|
||||
|
||||
def test_handle_native_json_schema_no_format_rule(self):
|
||||
provider = "openai"
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.model = "gpt-4"
|
||||
structured_output_schema = {"type": "object"}
|
||||
model_parameters = {}
|
||||
rules = []
|
||||
|
||||
updated_params = _handle_native_json_schema(
|
||||
provider, model_schema, structured_output_schema, model_parameters, rules
|
||||
)
|
||||
|
||||
assert "json_schema" in updated_params
|
||||
assert "response_format" not in updated_params
|
||||
|
||||
def test_handle_prompt_based_schema_with_system_prompt(self):
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content="Existing system prompt"),
|
||||
UserPromptMessage(content="User question"),
|
||||
]
|
||||
schema = {"type": "object"}
|
||||
|
||||
result = _handle_prompt_based_schema(prompt_messages, schema)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert "Existing system prompt" in result[0].content
|
||||
assert json.dumps(schema) in result[0].content
|
||||
assert isinstance(result[1], UserPromptMessage)
|
||||
|
||||
def test_handle_prompt_based_schema_without_system_prompt(self):
|
||||
prompt_messages = [UserPromptMessage(content="User question")]
|
||||
schema = {"type": "object"}
|
||||
|
||||
result = _handle_prompt_based_schema(prompt_messages, schema)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert json.dumps(schema) in result[0].content
|
||||
assert isinstance(result[1], UserPromptMessage)
|
||||
|
||||
def test_prepare_schema_for_model_gemini(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.model = "gemini-1.5-pro"
|
||||
schema = {"type": "object", "additionalProperties": False}
|
||||
|
||||
result = _prepare_schema_for_model("google", model_schema, schema)
|
||||
assert "additionalProperties" not in result
|
||||
|
||||
def test_prepare_schema_for_model_ollama(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.model = "llama3"
|
||||
schema = {"type": "object"}
|
||||
|
||||
result = _prepare_schema_for_model("ollama", model_schema, schema)
|
||||
assert result == schema
|
||||
|
||||
def test_prepare_schema_for_model_default(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.model = "gpt-4"
|
||||
schema = {"type": "object"}
|
||||
|
||||
result = _prepare_schema_for_model("openai", model_schema, schema)
|
||||
assert result == {"schema": schema, "name": "llm_response"}
|
||||
|
||||
def test_invoke_llm_with_structured_output_no_stream_native(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.support_structure_output = True
|
||||
model_schema.parameter_rules = [
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label={"en_US": ""},
|
||||
type=ParameterType.STRING,
|
||||
help={"en_US": ""},
|
||||
options=[ResponseFormat.JSON_SCHEMA],
|
||||
)
|
||||
]
|
||||
model_schema.model = "gpt-4o"
|
||||
|
||||
model_instance = MagicMock(spec=ModelInstance)
|
||||
mock_result = MagicMock(spec=LLMResult)
|
||||
mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
|
||||
mock_result.model = "gpt-4o"
|
||||
mock_result.usage = LLMUsage.empty_usage()
|
||||
mock_result.system_fingerprint = "fp_native"
|
||||
mock_result.prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
model_instance.invoke_llm.return_value = mock_result
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="hi")],
|
||||
json_schema={"type": "object"},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"result": "success"}
|
||||
assert result.system_fingerprint == "fp_native"
|
||||
|
||||
def test_invoke_llm_with_structured_output_no_stream_prompt_based(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.support_structure_output = False
|
||||
model_schema.parameter_rules = [
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label={"en_US": ""},
|
||||
type=ParameterType.STRING,
|
||||
help={"en_US": ""},
|
||||
options=[ResponseFormat.JSON],
|
||||
)
|
||||
]
|
||||
model_schema.model = "claude-3"
|
||||
|
||||
model_instance = MagicMock(spec=ModelInstance)
|
||||
mock_result = MagicMock(spec=LLMResult)
|
||||
mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
|
||||
mock_result.model = "claude-3"
|
||||
mock_result.usage = LLMUsage.empty_usage()
|
||||
mock_result.system_fingerprint = "fp_prompt"
|
||||
mock_result.prompt_messages = []
|
||||
|
||||
model_instance.invoke_llm.return_value = mock_result
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider="anthropic",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="hi")],
|
||||
json_schema={"type": "object"},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"result": "success"}
|
||||
assert result.system_fingerprint == "fp_prompt"
|
||||
|
||||
def test_invoke_llm_with_structured_output_no_string_error(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.support_structure_output = False
|
||||
model_schema.parameter_rules = []
|
||||
|
||||
model_instance = MagicMock(spec=ModelInstance)
|
||||
mock_result = MagicMock(spec=LLMResult)
|
||||
mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")])
|
||||
|
||||
model_instance.invoke_llm.return_value = mock_result
|
||||
|
||||
with pytest.raises(OutputParserError) as excinfo:
|
||||
invoke_llm_with_structured_output(
|
||||
provider="anthropic",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[],
|
||||
json_schema={},
|
||||
stream=False,
|
||||
)
|
||||
assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value)
|
||||
|
||||
def test_invoke_llm_with_structured_output_stream(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.support_structure_output = False
|
||||
model_schema.parameter_rules = []
|
||||
model_schema.model = "gpt-4"
|
||||
|
||||
model_instance = MagicMock(spec=ModelInstance)
|
||||
|
||||
# Mock chunks
|
||||
chunk1 = MagicMock(spec=LLMResultChunk)
|
||||
chunk1.delta = LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage()
|
||||
)
|
||||
chunk1.prompt_messages = [UserPromptMessage(content="hi")]
|
||||
chunk1.system_fingerprint = "fp1"
|
||||
|
||||
chunk2 = MagicMock(spec=LLMResultChunk)
|
||||
chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}'))
|
||||
chunk2.prompt_messages = [UserPromptMessage(content="hi")]
|
||||
chunk2.system_fingerprint = "fp1"
|
||||
|
||||
chunk3 = MagicMock(spec=LLMResultChunk)
|
||||
chunk3.delta = LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=" "),
|
||||
]
|
||||
),
|
||||
)
|
||||
chunk3.prompt_messages = [UserPromptMessage(content="hi")]
|
||||
chunk3.system_fingerprint = "fp1"
|
||||
|
||||
event4 = MagicMock()
|
||||
event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=""))
|
||||
|
||||
model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4]
|
||||
|
||||
generator = invoke_llm_with_structured_output(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="hi")],
|
||||
json_schema={},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks = list(generator)
|
||||
assert len(chunks) == 5
|
||||
assert chunks[-1].structured_output == {"key": "value"}
|
||||
assert chunks[-1].system_fingerprint == "fp1"
|
||||
assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")]
|
||||
|
||||
def test_invoke_llm_with_structured_output_stream_no_id_events(self):
|
||||
model_schema = MagicMock(spec=AIModelEntity)
|
||||
model_schema.support_structure_output = False
|
||||
model_schema.parameter_rules = []
|
||||
model_schema.model = "gpt-4"
|
||||
|
||||
model_instance = MagicMock(spec=ModelInstance)
|
||||
model_instance.invoke_llm.return_value = []
|
||||
|
||||
generator = invoke_llm_with_structured_output(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[],
|
||||
json_schema={},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
with pytest.raises(OutputParserError):
|
||||
list(generator)
|
||||
|
||||
def test_parse_structured_output_empty_string(self):
|
||||
with pytest.raises(OutputParserError):
|
||||
_parse_structured_output("")
|
||||
@@ -1,589 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
|
||||
|
||||
class TestLLMGenerator:
|
||||
@pytest.fixture
|
||||
def mock_model_instance(self):
|
||||
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
|
||||
instance = MagicMock()
|
||||
mock_manager.return_value.get_default_model_instance.return_value = instance
|
||||
mock_manager.return_value.get_model_instance.return_value = instance
|
||||
yield instance
|
||||
|
||||
@pytest.fixture
|
||||
def model_config_entity(self):
|
||||
return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7})
|
||||
|
||||
def test_generate_conversation_name_success(self, mock_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"})
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace:
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||
assert name == "Test Conversation Name"
|
||||
mock_trace.assert_called_once()
|
||||
|
||||
def test_generate_conversation_name_truncated(self, mock_model_instance):
|
||||
long_query = "a" * 2100
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"})
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", long_query)
|
||||
assert name == "Short Name"
|
||||
|
||||
def test_generate_conversation_name_empty_answer(self, mock_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = ""
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||
assert name == ""
|
||||
|
||||
def test_generate_conversation_name_json_repair(self, mock_model_instance):
|
||||
mock_response = MagicMock()
|
||||
# Invalid JSON that json_repair can fix
|
||||
mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||
assert name == "Repaired Name"
|
||||
|
||||
def test_generate_conversation_name_not_dict_result(self, mock_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '["not a dict"]'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||
assert name == "test query"
|
||||
|
||||
def test_generate_conversation_name_no_output_in_dict(self, mock_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"something": "else"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||
assert name == "test query"
|
||||
|
||||
def test_generate_conversation_name_long_output(self, mock_model_instance):
|
||||
long_output = "a" * 100
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output})
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||
assert len(name) == 78 # 75 + "..."
|
||||
assert name.endswith("...")
|
||||
|
||||
def test_generate_suggested_questions_after_answer_success(self, mock_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||
assert len(questions) == 2
|
||||
assert questions[0] == "Question 1?"
|
||||
|
||||
def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance):
|
||||
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
|
||||
mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed")
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||
assert questions == []
|
||||
|
||||
def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance):
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||
assert questions == []
|
||||
|
||||
def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance):
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||
assert questions == []
|
||||
|
||||
def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=True
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "Generated Prompt"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert result["prompt"] == "Generated Prompt"
|
||||
assert result["error"] == ""
|
||||
|
||||
def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=True
|
||||
)
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert "Failed to generate rule config" in result["error"]
|
||||
|
||||
def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=True
|
||||
)
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert "Failed to generate rule config" in result["error"]
|
||||
assert "Random error" in result["error"]
|
||||
|
||||
def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||
)
|
||||
# Mocking 3 calls for invoke_llm
|
||||
mock_res1 = MagicMock()
|
||||
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
|
||||
|
||||
mock_res2 = MagicMock()
|
||||
mock_res2.message.get_text_content.return_value = '"var1", "var2"'
|
||||
|
||||
mock_res3 = MagicMock()
|
||||
mock_res3.message.get_text_content.return_value = "Opening Statement"
|
||||
|
||||
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, mock_res3]
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert result["prompt"] == "Step 1 Prompt"
|
||||
assert result["variables"] == ["var1", "var2"]
|
||||
assert result["opening_statement"] == "Opening Statement"
|
||||
assert result["error"] == ""
|
||||
|
||||
def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||
)
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed")
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert "Failed to generate prefix prompt" in result["error"]
|
||||
|
||||
def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||
)
|
||||
mock_res1 = MagicMock()
|
||||
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
|
||||
|
||||
# Step 2 fails
|
||||
mock_model_instance.invoke_llm.side_effect = [mock_res1, InvokeError("Step 2 Failed"), MagicMock()]
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert "Failed to generate variables" in result["error"]
|
||||
|
||||
def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||
)
|
||||
mock_res1 = MagicMock()
|
||||
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
|
||||
|
||||
mock_res2 = MagicMock()
|
||||
mock_res2.message.get_text_content.return_value = '"var1"'
|
||||
|
||||
# Step 3 fails
|
||||
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, InvokeError("Step 3 Failed")]
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert "Failed to generate conversation opener" in result["error"]
|
||||
|
||||
def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleGeneratePayload(
|
||||
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||
)
|
||||
# Mock any step to throw Exception
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error")
|
||||
|
||||
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||
assert "Failed to handle unexpected exception" in result["error"]
|
||||
assert "Unexpected multi-step error" in result["error"]
|
||||
|
||||
def test_generate_code_python_success(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleCodeGeneratePayload(
|
||||
instruction="print hello", code_language="python", model_config=model_config_entity
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "print('hello')"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||
assert result["code"] == "print('hello')"
|
||||
assert result["language"] == "python"
|
||||
|
||||
def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleCodeGeneratePayload(
|
||||
instruction="console log hello", code_language="javascript", model_config=model_config_entity
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "console.log('hello')"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||
assert result["code"] == "console.log('hello')"
|
||||
assert result["language"] == "javascript"
|
||||
|
||||
def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||
|
||||
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||
assert "Failed to generate code" in result["error"]
|
||||
|
||||
def test_generate_code_exception(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_generate_qa_document_success(self, mock_model_instance):
|
||||
mock_response = MagicMock(spec=LLMResult)
|
||||
mock_response.message = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "QA Document Content"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_qa_document("tenant_id", "query", "English")
|
||||
assert result == "QA Document Content"
|
||||
|
||||
def test_generate_qa_document_type_error(self, mock_model_instance):
|
||||
mock_model_instance.invoke_llm.return_value = "Not an LLMResult"
|
||||
|
||||
with pytest.raises(TypeError, match="Expected LLMResult when stream=False"):
|
||||
LLMGenerator.generate_qa_document("tenant_id", "query", "English")
|
||||
|
||||
def test_generate_structured_output_success(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"type": "object", "properties": {}}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||
parsed_output = json.loads(result["output"])
|
||||
assert parsed_output["type"] == "object"
|
||||
assert result["error"] == ""
|
||||
|
||||
def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "{'type': 'object'}"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||
parsed_output = json.loads(result["output"])
|
||||
assert parsed_output["type"] == "object"
|
||||
|
||||
def test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "true" # parsed as bool
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
assert "Failed to parse structured output" in result["error"]
|
||||
|
||||
def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||
|
||||
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||
assert "Failed to generate JSON Schema" in result["error"]
|
||||
|
||||
def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity):
|
||||
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
# Mock __instruction_modify_common call via invoke_llm
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert result == {"modified": "prompt"}
|
||||
|
||||
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
last_run = MagicMock()
|
||||
last_run.query = "q"
|
||||
last_run.answer = "a"
|
||||
last_run.error = "e"
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert result == {"modified": "prompt"}
|
||||
|
||||
def test_instruction_modify_workflow_app_not_found(self):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found."):
|
||||
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
|
||||
|
||||
def test_instruction_modify_workflow_no_workflow(self):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = None
|
||||
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
|
||||
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", workflow_service)
|
||||
|
||||
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = workflow
|
||||
|
||||
last_run = MagicMock()
|
||||
last_run.node_type = "llm"
|
||||
last_run.status = "s"
|
||||
last_run.error = "e"
|
||||
# Return regular values, not Mocks
|
||||
last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]}
|
||||
last_run.load_full_inputs.return_value = {"in": "val"}
|
||||
|
||||
workflow_service.get_node_last_run.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_workflow(
|
||||
"tenant_id",
|
||||
"flow_id",
|
||||
"node_id",
|
||||
"current",
|
||||
"instruction",
|
||||
model_config_entity,
|
||||
"ideal",
|
||||
workflow_service,
|
||||
)
|
||||
assert result == {"modified": "workflow"}
|
||||
|
||||
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = workflow
|
||||
workflow_service.get_node_last_run.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_workflow(
|
||||
"tenant_id",
|
||||
"flow_id",
|
||||
"node_id",
|
||||
"current",
|
||||
"instruction",
|
||||
model_config_entity,
|
||||
"ideal",
|
||||
workflow_service,
|
||||
)
|
||||
assert result == {"modified": "fallback"}
|
||||
|
||||
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
# Cause exception in node_type logic
|
||||
workflow.graph_dict = {"graph": {"nodes": []}}
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = workflow
|
||||
workflow_service.get_node_last_run.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_workflow(
|
||||
"tenant_id",
|
||||
"flow_id",
|
||||
"node_id",
|
||||
"current",
|
||||
"instruction",
|
||||
model_config_entity,
|
||||
"ideal",
|
||||
workflow_service,
|
||||
)
|
||||
assert result == {"modified": "fallback"}
|
||||
|
||||
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = workflow
|
||||
|
||||
last_run = MagicMock()
|
||||
last_run.node_type = "llm"
|
||||
last_run.status = "s"
|
||||
last_run.error = "e"
|
||||
# Return regular empty list, not a Mock
|
||||
last_run.execution_metadata_dict = {"agent_log": []}
|
||||
last_run.load_full_inputs.return_value = {}
|
||||
|
||||
workflow_service.get_node_last_run.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_workflow(
|
||||
"tenant_id",
|
||||
"flow_id",
|
||||
"node_id",
|
||||
"current",
|
||||
"instruction",
|
||||
model_config_entity,
|
||||
"ideal",
|
||||
workflow_service,
|
||||
)
|
||||
assert result == {"modified": "workflow"}
|
||||
|
||||
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
|
||||
# Testing placeholders replacement via instruction_modify_legacy for convenience
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}"
|
||||
LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current_val", instruction, model_config_entity, "ideal"
|
||||
)
|
||||
|
||||
# Verify the call to invoke_llm contains replaced instruction
|
||||
args, kwargs = mock_model_instance.invoke_llm.call_args
|
||||
prompt_messages = kwargs["prompt_messages"]
|
||||
user_msg = prompt_messages[1].content
|
||||
user_msg_dict = json.loads(user_msg)
|
||||
assert "null" in user_msg_dict["instruction"] # because last_run is None and current is current_val etc.
|
||||
assert "current_val" in user_msg_dict["instruction"]
|
||||
|
||||
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No braces here"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
assert "Could not find a valid JSON object" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
# The exception message is "Expected a JSON object, but got list"
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity):
|
||||
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
|
||||
instance = MagicMock()
|
||||
mock_manager.return_value.get_model_instance.return_value = instance
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||
instance.invoke_llm.return_value = mock_response
|
||||
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}
|
||||
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = workflow
|
||||
workflow_service.get_node_last_run.return_value = None
|
||||
|
||||
LLMGenerator.instruction_modify_workflow(
|
||||
"tenant_id",
|
||||
"flow_id",
|
||||
"node_id",
|
||||
"current",
|
||||
"instruction",
|
||||
model_config_entity,
|
||||
"ideal",
|
||||
workflow_service,
|
||||
)
|
||||
|
||||
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert "Failed to generate code" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No JSON here"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
@@ -82,68 +82,6 @@ class TestTraceContextFilter:
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == "051581bf3bb55c45"
|
||||
|
||||
def test_otel_context_invalid_trace_id(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0
|
||||
mock_context.is_valid = True
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
# Use mocks for base context to ensure we can test the fallback
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("core.logging.filters.get_trace_id", return_value=""),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == ""
|
||||
|
||||
def test_otel_context_invalid_span_id(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||
mock_context.span_id = 0
|
||||
mock_context.is_valid = True
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == ""
|
||||
|
||||
def test_otel_context_span_none(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=None),
|
||||
mock.patch("core.logging.filters.get_trace_id", return_value=""),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == ""
|
||||
|
||||
def test_otel_context_exception(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Trigger exception in OTEL block
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception),
|
||||
mock.patch("core.logging.filters.get_trace_id", return_value=""),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == ""
|
||||
|
||||
|
||||
class TestIdentityContextFilter:
|
||||
def test_sets_empty_identity_without_request_context(self, log_record):
|
||||
@@ -176,119 +114,3 @@ class TestIdentityContextFilter:
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
|
||||
def test_sets_empty_identity_unauthenticated(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
mock_user = mock.MagicMock()
|
||||
mock_user.is_authenticated = False
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.user_id == ""
|
||||
|
||||
def test_sets_identity_for_account(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockAccount:
|
||||
pass
|
||||
|
||||
mock_user = MockAccount()
|
||||
mock_user.id = "account_id"
|
||||
mock_user.current_tenant_id = "tenant_id"
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.Account", MockAccount),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == "tenant_id"
|
||||
assert log_record.user_id == "account_id"
|
||||
assert log_record.user_type == "account"
|
||||
|
||||
def test_sets_identity_for_account_no_tenant(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockAccount:
|
||||
pass
|
||||
|
||||
mock_user = MockAccount()
|
||||
mock_user.id = "account_id"
|
||||
mock_user.current_tenant_id = None
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.Account", MockAccount),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == ""
|
||||
assert log_record.user_id == "account_id"
|
||||
assert log_record.user_type == "account"
|
||||
|
||||
def test_sets_identity_for_end_user(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockEndUser:
|
||||
pass
|
||||
|
||||
class AnotherClass:
|
||||
pass
|
||||
|
||||
mock_user = MockEndUser()
|
||||
mock_user.id = "end_user_id"
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
mock_user.type = "custom_type"
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.model.EndUser", MockEndUser),
|
||||
mock.patch("models.Account", AnotherClass),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == "tenant_id"
|
||||
assert log_record.user_id == "end_user_id"
|
||||
assert log_record.user_type == "custom_type"
|
||||
|
||||
def test_sets_identity_for_end_user_default_type(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockEndUser:
|
||||
pass
|
||||
|
||||
class AnotherClass:
|
||||
pass
|
||||
|
||||
mock_user = MockEndUser()
|
||||
mock_user.id = "end_user_id"
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
mock_user.type = None
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.model.EndUser", MockEndUser),
|
||||
mock.patch("models.Account", AnotherClass),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == "tenant_id"
|
||||
assert log_record.user_id == "end_user_id"
|
||||
assert log_record.user_type == "end_user"
|
||||
|
||||
@@ -1,39 +1,27 @@
|
||||
"""Unit tests for MCP OAuth authentication flow."""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.auth.auth_flow import (
|
||||
OAUTH_STATE_EXPIRY_SECONDS,
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX,
|
||||
OAuthCallbackState,
|
||||
_create_secure_redis_state,
|
||||
_parse_token_response,
|
||||
_retrieve_redis_state,
|
||||
auth,
|
||||
build_oauth_authorization_server_metadata_discovery_urls,
|
||||
build_protected_resource_metadata_discovery_urls,
|
||||
check_support_resource_discovery,
|
||||
client_credentials_flow,
|
||||
discover_oauth_authorization_server_metadata,
|
||||
discover_oauth_metadata,
|
||||
discover_protected_resource_metadata,
|
||||
exchange_authorization,
|
||||
generate_pkce_challenge,
|
||||
get_effective_scope,
|
||||
handle_callback,
|
||||
refresh_authorization,
|
||||
register_client,
|
||||
start_authorization,
|
||||
)
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
@@ -776,555 +764,3 @@ class TestAuthOrchestration:
|
||||
auth(mock_provider, authorization_code="auth-code")
|
||||
|
||||
assert "Existing OAuth client information is required" in str(exc_info.value)
|
||||
|
||||
def test_generate_pkce_challenge(self):
|
||||
verifier, challenge = generate_pkce_challenge()
|
||||
assert verifier
|
||||
assert challenge
|
||||
assert "=" not in verifier
|
||||
assert "=" not in challenge
|
||||
|
||||
def test_build_protected_resource_metadata_discovery_urls(self):
|
||||
# Case 1: WWW-Auth URL provided
|
||||
urls = build_protected_resource_metadata_discovery_urls(
|
||||
"https://auth.example.com/prm", "https://api.example.com"
|
||||
)
|
||||
assert "https://auth.example.com/prm" in urls
|
||||
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
|
||||
|
||||
# Case 2: No WWW-Auth URL, with path
|
||||
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1")
|
||||
assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls
|
||||
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
|
||||
|
||||
# Case 3: No path
|
||||
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
|
||||
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
|
||||
|
||||
def test_build_oauth_authorization_server_metadata_discovery_urls(self):
|
||||
# Case 1: with auth_server_url
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(
|
||||
"https://auth.example.com", "https://api.example.com"
|
||||
)
|
||||
assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls
|
||||
assert "https://auth.example.com/.well-known/openid-configuration" in urls
|
||||
|
||||
# Case 2: with path
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant")
|
||||
assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls
|
||||
assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_protected_resource_metadata(self, mock_get):
|
||||
# Success
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"resource": "https://api.example.com",
|
||||
"authorization_servers": ["https://auth"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com")
|
||||
assert result is not None
|
||||
assert result.resource == "https://api.example.com"
|
||||
|
||||
# 404 then Success
|
||||
res404 = Mock()
|
||||
res404.status_code = 404
|
||||
mock_get.side_effect = [res404, mock_response]
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com/path")
|
||||
assert result is not None
|
||||
assert result.resource == "https://api.example.com"
|
||||
|
||||
# Error handling
|
||||
mock_get.side_effect = httpx.RequestError("Error")
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_authorization_server_metadata(self, mock_get):
|
||||
# Success
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://auth.example.com/auth",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
|
||||
assert result is not None
|
||||
assert result.authorization_endpoint == "https://auth.example.com/auth"
|
||||
|
||||
# 404
|
||||
res404 = Mock()
|
||||
res404.status_code = 404
|
||||
mock_get.side_effect = [res404, mock_response]
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant")
|
||||
assert result is not None
|
||||
assert result.authorization_endpoint == "https://auth.example.com/auth"
|
||||
|
||||
# ValidationError
|
||||
mock_response.json.return_value = {"invalid": "data"}
|
||||
mock_get.side_effect = None
|
||||
mock_get.return_value = mock_response
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
def test_get_effective_scope(self):
|
||||
prm = ProtectedResourceMetadata(
|
||||
resource="https://api.example.com",
|
||||
authorization_servers=["https://auth"],
|
||||
scopes_supported=["read", "write"],
|
||||
)
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/auth",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
scopes_supported=["openid", "profile"],
|
||||
)
|
||||
|
||||
# 1. WWW-Auth priority
|
||||
assert get_effective_scope("scope1", prm, asm, "client") == "scope1"
|
||||
# 2. PRM priority
|
||||
assert get_effective_scope(None, prm, asm, "client") == "read write"
|
||||
# 3. ASM priority
|
||||
assert get_effective_scope(None, None, asm, "client") == "openid profile"
|
||||
# 4. Client configured
|
||||
assert get_effective_scope(None, None, None, "client") == "client"
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.redis_client")
|
||||
def test_redis_state_management(self, mock_redis):
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id="p1",
|
||||
tenant_id="t1",
|
||||
server_url="https://api",
|
||||
metadata=None,
|
||||
client_information=OAuthClientInformation(client_id="c1"),
|
||||
code_verifier="cv",
|
||||
redirect_uri="https://re",
|
||||
)
|
||||
|
||||
# Create
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
assert state_key
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
# Retrieve Success
|
||||
mock_redis.get.return_value = state_data.model_dump_json()
|
||||
retrieved = _retrieve_redis_state(state_key)
|
||||
assert retrieved.provider_id == "p1"
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
# Retrieve Failure - Not found
|
||||
mock_redis.get.return_value = None
|
||||
with pytest.raises(ValueError, match="expired or does not exist"):
|
||||
_retrieve_redis_state("absent")
|
||||
|
||||
# Retrieve Failure - Invalid JSON
|
||||
mock_redis.get.return_value = "invalid"
|
||||
with pytest.raises(ValueError, match="Invalid state parameter"):
|
||||
_retrieve_redis_state("invalid")
|
||||
|
||||
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||
@patch("core.mcp.auth.auth_flow.exchange_authorization")
|
||||
def test_handle_callback(self, mock_exchange, mock_retrieve):
|
||||
state = Mock(spec=OAuthCallbackState)
|
||||
state.server_url = "https://api"
|
||||
state.metadata = None
|
||||
state.client_information = Mock()
|
||||
state.code_verifier = "cv"
|
||||
state.redirect_uri = "https://re"
|
||||
mock_retrieve.return_value = state
|
||||
|
||||
tokens = Mock(spec=OAuthTokens)
|
||||
mock_exchange.return_value = tokens
|
||||
|
||||
s, t = handle_callback("key", "code")
|
||||
assert s == state
|
||||
assert t == tokens
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_check_support_resource_discovery(self, mock_get):
|
||||
# Case 1: authorization_servers (plural)
|
||||
res = Mock()
|
||||
res.status_code = 200
|
||||
res.json.return_value = {"authorization_servers": ["https://auth1"]}
|
||||
mock_get.return_value = res
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is True
|
||||
assert url == "https://auth1"
|
||||
|
||||
# Case 2: authorization_server_url (singular alias)
|
||||
res.json.return_value = {"authorization_server_url": ["https://auth2"]}
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is True
|
||||
assert url == "https://auth2"
|
||||
|
||||
# Case 3: Missing fields
|
||||
res.json.return_value = {"nothing": []}
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
# Case 4: 404
|
||||
res.status_code = 404
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
# Case 5: RequestError
|
||||
mock_get.side_effect = httpx.RequestError("Error")
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
def test_discover_oauth_metadata(self):
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
mock_prm.return_value = ProtectedResourceMetadata(
|
||||
resource="https://api", authorization_servers=["https://auth"]
|
||||
)
|
||||
mock_asm.return_value = Mock(spec=OAuthMetadata)
|
||||
|
||||
asm, prm, hint = discover_oauth_metadata("https://api")
|
||||
assert asm == mock_asm.return_value
|
||||
assert prm == mock_prm.return_value
|
||||
mock_asm.assert_called_with("https://auth", "https://api", None)
|
||||
|
||||
def test_start_authorization(self):
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/authorize",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
client_info = OAuthClientInformation(client_id="c1")
|
||||
|
||||
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create:
|
||||
mock_create.return_value = "state-key"
|
||||
|
||||
# Success with scope
|
||||
url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read")
|
||||
assert "scope=read" in url
|
||||
assert "state=state-key" in url
|
||||
|
||||
# Success without metadata
|
||||
url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1")
|
||||
assert "https://api/authorize" in url
|
||||
|
||||
# Failure: incompatible auth server
|
||||
metadata.response_types_supported = ["implicit"]
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1")
|
||||
|
||||
def test_parse_token_response(self):
|
||||
# Case 1: JSON
|
||||
res = Mock()
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at"
|
||||
|
||||
# Case 2: Form-urlencoded
|
||||
res.headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
res.text = "access_token=at2&token_type=Bearer"
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at2"
|
||||
|
||||
# Case 3: No content-type, but JSON
|
||||
res.headers = {}
|
||||
res.json.return_value = {"access_token": "at3", "token_type": "Bearer"}
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at3"
|
||||
|
||||
# Case 4: No content-type, not JSON, but Form
|
||||
res.json.side_effect = json.JSONDecodeError("msg", "doc", 0)
|
||||
res.text = "access_token=at4&token_type=Bearer"
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at4"
|
||||
|
||||
# Case 5: Validation Error fallback
|
||||
res.json.side_effect = ValidationError.from_exception_data("error", [])
|
||||
res.text = "access_token=at5&token_type=Bearer"
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at5"
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_exchange_authorization(self, mock_post):
|
||||
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/authorize",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
|
||||
# Success
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
|
||||
mock_post.return_value = res
|
||||
|
||||
tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
|
||||
assert tokens.access_token == "at"
|
||||
|
||||
# Failure: Unsupported grant type
|
||||
metadata.grant_types_supported = ["client_credentials"]
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
|
||||
|
||||
# Failure: HTTP error
|
||||
metadata.grant_types_supported = ["authorization_code"]
|
||||
res.is_success = False
|
||||
res.status_code = 400
|
||||
with pytest.raises(ValueError, match="Token exchange failed"):
|
||||
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_refresh_authorization(self, mock_post):
|
||||
# Case 1: with client_secret
|
||||
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
|
||||
# Success
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"}
|
||||
mock_post.return_value = res
|
||||
|
||||
tokens = refresh_authorization("https://api", None, client_info, "rt")
|
||||
assert tokens.access_token == "at_new"
|
||||
assert mock_post.call_args[1]["data"]["client_secret"] == "s1"
|
||||
|
||||
# Failure: MaxRetriesExceededError
|
||||
mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries")
|
||||
with pytest.raises(MCPRefreshTokenError):
|
||||
refresh_authorization("https://api", None, client_info, "rt")
|
||||
|
||||
# Failure: HTTP error
|
||||
mock_post.side_effect = None
|
||||
res.is_success = False
|
||||
res.text = "error_msg"
|
||||
with pytest.raises(MCPRefreshTokenError, match="error_msg"):
|
||||
refresh_authorization("https://api", None, client_info, "rt")
|
||||
|
||||
# Failure: Incompatible metadata
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
refresh_authorization("https://api", metadata, client_info, "rt")
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_client_credentials_flow(self, mock_post):
|
||||
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
|
||||
# Success with secret
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"}
|
||||
mock_post.return_value = res
|
||||
|
||||
tokens = client_credentials_flow("https://api", None, client_info, "read")
|
||||
assert tokens.access_token == "at_cc"
|
||||
args, kwargs = mock_post.call_args
|
||||
assert "Authorization" in kwargs["headers"]
|
||||
|
||||
# Success without secret
|
||||
client_info_no_secret = OAuthClientInformation(client_id="c2")
|
||||
tokens = client_credentials_flow("https://api", None, client_info_no_secret)
|
||||
args, kwargs = mock_post.call_args
|
||||
assert kwargs["data"]["client_id"] == "c2"
|
||||
|
||||
# Failure: Incompatible metadata
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
client_credentials_flow("https://api", metadata, client_info)
|
||||
|
||||
# Failure: HTTP error
|
||||
res.is_success = False
|
||||
res.status_code = 401
|
||||
res.text = "Unauthorized"
|
||||
with pytest.raises(ValueError, match="Client credentials token request failed"):
|
||||
client_credentials_flow("https://api", None, client_info)
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_register_client(self, mock_post):
|
||||
# Case 1: Success with metadata
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
registration_endpoint="https://auth/register",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"])
|
||||
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.json.return_value = {
|
||||
"client_id": "c_new",
|
||||
"client_secret": "s_new",
|
||||
"client_name": "Dify",
|
||||
"redirect_uris": ["https://re"],
|
||||
}
|
||||
mock_post.return_value = res
|
||||
|
||||
info = register_client("https://api", metadata, client_metadata)
|
||||
assert info.client_id == "c_new"
|
||||
|
||||
# Case 2: Success without metadata
|
||||
info = register_client("https://api", None, client_metadata)
|
||||
assert mock_post.call_args[0][0] == "https://api/register"
|
||||
|
||||
# Case 3: Metadata provided but no endpoint
|
||||
metadata.registration_endpoint = None
|
||||
with pytest.raises(ValueError, match="does not support dynamic client registration"):
|
||||
register_client("https://api", metadata, client_metadata)
|
||||
|
||||
# Failure: HTTP
|
||||
res.is_success = False
|
||||
res.raise_for_status = Mock()
|
||||
res.status_code = 400
|
||||
# If is_success is false, it should call raise_for_status
|
||||
register_client("https://api", None, client_metadata)
|
||||
res.raise_for_status.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_failures(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
|
||||
# Case 1: No server metadata
|
||||
mock_discover.return_value = (None, None, None)
|
||||
with pytest.raises(ValueError, match="Failed to discover OAuth metadata"):
|
||||
auth(provider)
|
||||
|
||||
# Case 2: No client info, exchange code provided
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
provider.retrieve_client_information.return_value = None
|
||||
with pytest.raises(ValueError, match="Existing OAuth client information is required"):
|
||||
auth(provider, authorization_code="code")
|
||||
|
||||
# Case 3: CLIENT_CREDENTIALS but client must provide info
|
||||
asm.grant_types_supported = ["client_credentials"]
|
||||
with pytest.raises(ValueError, match="requires client_id and client_secret"):
|
||||
auth(provider)
|
||||
|
||||
# Case 4: Client registration fails
|
||||
asm.grant_types_supported = ["authorization_code"]
|
||||
with patch("core.mcp.auth.auth_flow.register_client") as mock_reg:
|
||||
mock_reg.side_effect = httpx.RequestError("Reg failed")
|
||||
with pytest.raises(ValueError, match="Could not register OAuth client"):
|
||||
auth(provider)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_client_credentials(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
provider.decrypt_credentials.return_value = {"scope": "read"}
|
||||
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["client_credentials"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc:
|
||||
mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer")
|
||||
|
||||
result = auth(provider)
|
||||
assert result.response == {"result": "success"}
|
||||
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||
assert result.actions[0].data["grant_type"] == "client_credentials"
|
||||
|
||||
# Failure in CC flow
|
||||
mock_cc.side_effect = ValueError("CC Failed")
|
||||
with pytest.raises(ValueError, match="Client credentials flow failed"):
|
||||
auth(provider)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_authorization_code(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
|
||||
provider.decrypt_credentials.return_value = {}
|
||||
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
|
||||
# Case 1: Exchange code
|
||||
with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve:
|
||||
state = Mock(spec=OAuthCallbackState)
|
||||
state.code_verifier = "cv"
|
||||
state.redirect_uri = "https://re"
|
||||
mock_retrieve.return_value = state
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange:
|
||||
mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer")
|
||||
|
||||
# Success
|
||||
result = auth(provider, authorization_code="code", state_param="sp")
|
||||
assert result.response == {"result": "success"}
|
||||
|
||||
# Missing state_param
|
||||
with pytest.raises(ValueError, match="State parameter is required"):
|
||||
auth(provider, authorization_code="code")
|
||||
|
||||
# Missing verifier in state
|
||||
state.code_verifier = None
|
||||
with pytest.raises(ValueError, match="Missing code_verifier"):
|
||||
auth(provider, authorization_code="code", state_param="sp")
|
||||
|
||||
# Invalid state
|
||||
mock_retrieve.side_effect = ValueError("Invalid")
|
||||
with pytest.raises(ValueError, match="Invalid state parameter"):
|
||||
auth(provider, authorization_code="code", state_param="sp")
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_refresh_failure(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
|
||||
provider.decrypt_credentials.return_value = {}
|
||||
provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt")
|
||||
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh:
|
||||
mock_refresh.side_effect = ValueError("Refresh Failed")
|
||||
with pytest.raises(ValueError, match="Could not refresh OAuth tokens"):
|
||||
auth(provider)
|
||||
|
||||
@@ -322,475 +322,3 @@ def test_sse_client_concurrent_access():
|
||||
assert len(received_messages) == 10
|
||||
for i in range(10):
|
||||
assert f"message_{i}" in received_messages
|
||||
|
||||
|
||||
class TestStatusClasses:
|
||||
"""Tests for _StatusReady and _StatusError data containers."""
|
||||
|
||||
def test_status_ready_stores_endpoint(self):
|
||||
from core.mcp.client.sse_client import _StatusReady
|
||||
|
||||
status = _StatusReady("http://example.com/messages/")
|
||||
assert status.endpoint_url == "http://example.com/messages/"
|
||||
|
||||
def test_status_error_stores_exception(self):
|
||||
from core.mcp.client.sse_client import _StatusError
|
||||
|
||||
exc = ValueError("bad endpoint")
|
||||
status = _StatusError(exc)
|
||||
assert status.exc is exc
|
||||
|
||||
|
||||
class TestSSETransportInit:
|
||||
"""Tests for SSETransport default and explicit init values."""
|
||||
|
||||
def test_defaults(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
t = SSETransport("http://example.com/sse")
|
||||
assert t.url == "http://example.com/sse"
|
||||
assert t.headers == {}
|
||||
assert t.timeout == 5.0
|
||||
assert t.sse_read_timeout == 60.0
|
||||
assert t.endpoint_url is None
|
||||
assert t.event_source is None
|
||||
|
||||
def test_explicit_headers_not_mutated(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
hdrs = {"X-Foo": "bar"}
|
||||
t = SSETransport("http://example.com/sse", headers=hdrs)
|
||||
assert t.headers is hdrs
|
||||
|
||||
|
||||
class TestHandleEndpointEvent:
|
||||
"""Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch."""
|
||||
|
||||
def test_invalid_origin_puts_status_error(self):
|
||||
from core.mcp.client.sse_client import SSETransport, _StatusError
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Provide a full URL with a different origin so urljoin keeps it as-is
|
||||
transport._handle_endpoint_event("http://evil.com/messages/", status_queue)
|
||||
|
||||
result = status_queue.get_nowait()
|
||||
assert isinstance(result, _StatusError)
|
||||
assert "does not match" in str(result.exc)
|
||||
|
||||
def test_valid_origin_puts_status_ready(self):
|
||||
from core.mcp.client.sse_client import SSETransport, _StatusReady
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
transport._handle_endpoint_event("/messages/?session_id=abc", status_queue)
|
||||
|
||||
result = status_queue.get_nowait()
|
||||
assert isinstance(result, _StatusReady)
|
||||
assert "example.com" in result.endpoint_url
|
||||
|
||||
|
||||
class TestHandleSSEEvent:
|
||||
"""Tests for SSETransport._handle_sse_event covering all match branches."""
|
||||
|
||||
def _make_sse(self, event_type: str, data: str):
|
||||
sse = Mock()
|
||||
sse.event = event_type
|
||||
sse.data = data
|
||||
return sse
|
||||
|
||||
def test_message_event_dispatched(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue)
|
||||
|
||||
item = read_queue.get_nowait()
|
||||
assert hasattr(item, "message")
|
||||
|
||||
def test_unknown_event_logs_warning_and_does_nothing(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue)
|
||||
|
||||
assert read_queue.empty()
|
||||
assert status_queue.empty()
|
||||
|
||||
|
||||
class TestSSEReader:
|
||||
"""Tests for SSETransport.sse_reader exception branches."""
|
||||
|
||||
def test_read_error_closes_cleanly(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
event_source = Mock()
|
||||
event_source.iter_sse.side_effect = httpx.ReadError("connection reset")
|
||||
|
||||
transport.sse_reader(event_source, read_queue, status_queue)
|
||||
|
||||
# Finally block always puts None as sentinel
|
||||
sentinel = read_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_generic_exception_puts_exc_then_none(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
boom = RuntimeError("unexpected!")
|
||||
event_source = Mock()
|
||||
event_source.iter_sse.side_effect = boom
|
||||
|
||||
transport.sse_reader(event_source, read_queue, status_queue)
|
||||
|
||||
exc_item = read_queue.get_nowait()
|
||||
assert exc_item is boom
|
||||
|
||||
sentinel = read_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
|
||||
class TestSendMessage:
|
||||
"""Tests for SSETransport._send_message."""
|
||||
|
||||
def _make_session_message(self):
|
||||
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(msg_json)
|
||||
return types.SessionMessage(msg)
|
||||
|
||||
def test_sends_post_and_raises_for_status(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_client = Mock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
transport._send_message(mock_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
mock_client.post.assert_called_once()
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
class TestPostWriter:
|
||||
"""Tests for SSETransport.post_writer exception branches."""
|
||||
|
||||
def _make_session_message(self):
|
||||
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(msg_json)
|
||||
return types.SessionMessage(msg)
|
||||
|
||||
def test_none_message_exits_loop(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
write_queue.put(None) # Signal shutdown immediately
|
||||
|
||||
mock_client = Mock()
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# Should put final None sentinel
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_exception_in_message_put_back_to_queue(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
exc = ValueError("some error")
|
||||
write_queue.put(exc) # Exception goes in first
|
||||
write_queue.put(None) # Then shutdown signal
|
||||
|
||||
mock_client = Mock()
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# The exception should be re-queued, then None from loop exit, then None from finally
|
||||
item1 = write_queue.get_nowait()
|
||||
assert isinstance(item1, Exception)
|
||||
|
||||
def test_read_error_shuts_down_cleanly(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
write_queue.put(session_msg)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_client = Mock()
|
||||
mock_client.post.side_effect = httpx.ReadError("connection dropped")
|
||||
|
||||
# post_writer calls _send_message which calls client.post → ReadError propagates
|
||||
# The ReadError is raised inside _send_message → propagates out of the while loop
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# finally always puts None
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_generic_exception_puts_exc_in_queue(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
write_queue.put(session_msg)
|
||||
|
||||
mock_client = Mock()
|
||||
boom = RuntimeError("boom")
|
||||
mock_client.post.side_effect = boom
|
||||
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
exc_item = write_queue.get_nowait()
|
||||
assert isinstance(exc_item, Exception)
|
||||
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_queue_empty_timeout_continues_loop(self):
|
||||
"""Cover the 'except queue.Empty: continue' branch (line 188) in post_writer."""
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
mock_client = Mock()
|
||||
|
||||
# Patch queue.Queue.get so it raises Empty first, then returns None (shutdown)
|
||||
call_count = {"n": 0}
|
||||
original_get = write_queue.get
|
||||
|
||||
def patched_get(*args, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise queue.Empty
|
||||
|
||||
write_queue.get = patched_get # type: ignore[method-assign]
|
||||
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# finally always puts None sentinel
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
assert call_count["n"] >= 2 # Empty on first, None on second (and possibly more retries)
|
||||
|
||||
|
||||
class TestWaitForEndpoint:
|
||||
"""Tests for SSETransport._wait_for_endpoint edge cases."""
|
||||
|
||||
def test_raises_on_empty_queue(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue() # empty
|
||||
|
||||
with pytest.raises(ValueError, match="failed to get endpoint URL"):
|
||||
transport._wait_for_endpoint(status_queue)
|
||||
|
||||
def test_raises_status_error_exception(self):
|
||||
from core.mcp.client.sse_client import SSETransport, _StatusError
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
exc = ValueError("malicious endpoint")
|
||||
status_queue.put(_StatusError(exc))
|
||||
|
||||
with pytest.raises(ValueError, match="malicious endpoint"):
|
||||
transport._wait_for_endpoint(status_queue)
|
||||
|
||||
def test_raises_on_unknown_status_type(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Put an object that is neither _StatusReady nor _StatusError
|
||||
status_queue.put("unexpected_value")
|
||||
|
||||
with pytest.raises(ValueError, match="failed to get endpoint URL"):
|
||||
transport._wait_for_endpoint(status_queue)
|
||||
|
||||
|
||||
class TestSSEClientRuntimeError:
|
||||
"""Test sse_client context manager handles RuntimeError on close()."""
|
||||
|
||||
def test_runtime_error_on_close_is_suppressed(self):
|
||||
"""Ensure RuntimeError raised by event_source.response.close() is caught."""
|
||||
test_url = "http://test.example/sse"
|
||||
|
||||
class MockSSEEvent:
|
||||
def __init__(self, event_type: str, data: str):
|
||||
self.event = event_type
|
||||
self.data = data
|
||||
|
||||
endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
|
||||
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc:
|
||||
mock_client = Mock()
|
||||
mock_cf.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_es = Mock()
|
||||
mock_es.response.raise_for_status.return_value = None
|
||||
mock_es.iter_sse.return_value = [endpoint_event]
|
||||
# Make close() raise RuntimeError to exercise line 307-308
|
||||
mock_es.response.close.side_effect = RuntimeError("already closed")
|
||||
mock_sc.return_value.__enter__.return_value = mock_es
|
||||
|
||||
# Should NOT raise even though close() raises RuntimeError
|
||||
with contextlib.suppress(Exception):
|
||||
with sse_client(test_url) as (rq, wq):
|
||||
pass
|
||||
|
||||
|
||||
class TestStandaloneSendMessage:
|
||||
"""Tests for the module-level send_message() function."""
|
||||
|
||||
def _make_session_message(self):
|
||||
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(msg_json)
|
||||
return types.SessionMessage(msg)
|
||||
|
||||
def test_send_message_success(self):
|
||||
from core.mcp.client.sse_client import send_message
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_http_client = Mock()
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
send_message(mock_http_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
def test_send_message_raises_on_http_error(self):
|
||||
from core.mcp.client.sse_client import send_message
|
||||
|
||||
mock_http_client = Mock()
|
||||
mock_http_client.post.side_effect = httpx.ConnectError("refused")
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
send_message(mock_http_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
def test_send_message_raises_for_status_failure(self):
|
||||
from core.mcp.client.sse_client import send_message
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=Mock(), response=Mock(status_code=404)
|
||||
)
|
||||
mock_http_client = Mock()
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
send_message(mock_http_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
|
||||
class TestReadMessages:
|
||||
"""Tests for the module-level read_messages() generator."""
|
||||
|
||||
def _make_mock_sse_event(self, event_type: str, data: str):
|
||||
ev = Mock()
|
||||
ev.event = event_type
|
||||
ev.data = data
|
||||
return ev
|
||||
|
||||
def test_valid_message_event_yields_session_message(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
mock_sse_event = self._make_mock_sse_event("message", valid_json)
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = [mock_sse_event]
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
assert len(results) == 1
|
||||
assert hasattr(results[0], "message")
|
||||
|
||||
def test_invalid_json_yields_exception(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
mock_sse_event = self._make_mock_sse_event("message", "{not valid json}")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = [mock_sse_event]
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], Exception)
|
||||
|
||||
def test_non_message_event_is_skipped(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = [mock_sse_event]
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
# Non-message events produce no output
|
||||
assert results == []
|
||||
|
||||
def test_outer_exception_yields_exc(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
boom = RuntimeError("stream broken")
|
||||
mock_client = Mock()
|
||||
mock_client.events.side_effect = boom
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
assert len(results) == 1
|
||||
assert results[0] is boom
|
||||
|
||||
def test_multiple_events_mixed(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}'
|
||||
events = [
|
||||
self._make_mock_sse_event("endpoint", "/messages/"),
|
||||
self._make_mock_sse_event("message", valid_json),
|
||||
self._make_mock_sse_event("message", "{bad json}"),
|
||||
]
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = events
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
# endpoint is skipped; 1 valid SessionMessage + 1 Exception
|
||||
assert len(results) == 2
|
||||
assert hasattr(results[0], "message")
|
||||
assert isinstance(results[1], Exception)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,617 +0,0 @@
|
||||
import queue
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import HTTPStatusError, Request, Response
|
||||
from pydantic import BaseModel, ConfigDict, RootModel
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
from core.mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCResponse,
|
||||
Notification,
|
||||
RequestParams,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.types import (
|
||||
Request as MCPRequest,
|
||||
)
|
||||
|
||||
|
||||
class MockRequestParams(RequestParams):
|
||||
name: str = "default"
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MockRequest(MCPRequest[MockRequestParams, str]):
|
||||
method: str = "test/request"
|
||||
params: MockRequestParams = MockRequestParams()
|
||||
|
||||
|
||||
class MockResult(BaseModel):
|
||||
result: str
|
||||
|
||||
|
||||
class MockNotificationParams(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class MockNotification(Notification[MockNotificationParams, str]):
|
||||
method: str = "test/notification"
|
||||
params: MockNotificationParams
|
||||
|
||||
|
||||
class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]):
|
||||
pass
|
||||
|
||||
|
||||
class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]):
|
||||
pass
|
||||
|
||||
|
||||
class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.received_requests = []
|
||||
self.received_notifications = []
|
||||
self.handled_incoming = []
|
||||
|
||||
def _received_request(self, responder):
|
||||
self.received_requests.append(responder)
|
||||
|
||||
def _received_notification(self, notification):
|
||||
self.received_notifications.append(notification)
|
||||
|
||||
def _handle_incoming(self, item):
|
||||
self.handled_incoming.append(item)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streams():
|
||||
return queue.Queue(), queue.Queue()
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_request_responder_respond(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
on_complete = MagicMock()
|
||||
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
|
||||
responder.respond(MockResult(result="ok"))
|
||||
|
||||
with responder as r:
|
||||
r.respond(MockResult(result="ok"))
|
||||
with pytest.raises(AssertionError, match="Request already responded to"):
|
||||
r.respond(MockResult(result="error"))
|
||||
|
||||
assert responder.completed is True
|
||||
on_complete.assert_called_once_with(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert isinstance(msg.message.root, JSONRPCResponse)
|
||||
assert msg.message.root.result == {"result": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_request_responder_cancel(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
on_complete = MagicMock()
|
||||
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
|
||||
responder.cancel()
|
||||
|
||||
with responder as r:
|
||||
r.cancel()
|
||||
|
||||
assert responder.completed is True
|
||||
on_complete.assert_called_once_with(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert isinstance(msg.message.root, JSONRPCError)
|
||||
assert msg.message.root.error.message == "Request cancelled"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_base_session_lifecycle(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session as s:
|
||||
assert isinstance(s, MockSession)
|
||||
assert s._executor is not None
|
||||
assert s._receiver_future is not None
|
||||
|
||||
session._receiver_future.result(timeout=5.0)
|
||||
assert session._receiver_future.done()
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_success(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_response():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"})
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_response, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.send_request(request, MockResult)
|
||||
assert result.result == "hello world"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_retry_loop_coverage(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_delayed_response():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
time.sleep(0.2)
|
||||
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"})
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_delayed_response, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1))
|
||||
assert result.result == "slow"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_jsonrpc_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error"))
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPConnectionError) as exc:
|
||||
session.send_request(request, MockResult)
|
||||
assert exc.value.args[0].message == "Error"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_auth_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized"))
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPAuthError):
|
||||
session.send_request(request, MockResult)
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_http_status_error_coverage(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_direct_http_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
# To cover line 263 in base_session.py, we MUST put non-401 HTTPStatusError
|
||||
# DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError.
|
||||
response = Response(status_code=403, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Forbidden", request=response.request, response=response)
|
||||
session._response_streams[req_id].put(error)
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_direct_http_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
# We still need the session for request ID generation and queue setup
|
||||
with session:
|
||||
with pytest.raises(MCPConnectionError) as exc:
|
||||
session.send_request(request, MockResult)
|
||||
assert exc.value.args[0].code == 403
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_http_status_auth_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
response = Response(status_code=401, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPAuthError):
|
||||
session.send_request(request, MockResult)
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
notification = MockNotification(method="notify", params=MockNotificationParams(message="hi"))
|
||||
|
||||
session.send_notification(notification, related_request_id="rel-1")
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert isinstance(msg.message.root, JSONRPCNotification)
|
||||
assert msg.message.root.method == "notify"
|
||||
assert msg.message.root.params == {"message": "hi"}
|
||||
assert msg.metadata.related_request_id == "rel-1"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_request(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if session.received_requests:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(session.received_requests) == 1
|
||||
responder = session.received_requests[0]
|
||||
assert responder.request_id == 1
|
||||
assert responder.request.root.method == "test/request"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if session.received_notifications:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(session.received_notifications) == 1
|
||||
assert isinstance(session.received_notifications[0].root, MockNotification)
|
||||
assert session.received_notifications[0].root.method == "test/notification"
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_receive_loop_cancel_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification)
|
||||
|
||||
with session:
|
||||
req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if "req-1" in session._in_flight:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert "req-1" in session._in_flight
|
||||
responder = session._in_flight["req-1"]
|
||||
|
||||
with responder:
|
||||
cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if responder.completed:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert responder.completed is True
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert isinstance(msg.message.root, JSONRPCError)
|
||||
assert msg.message.root.id == "req-1"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_exception(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
read_stream.put(Exception("Unexpected error"))
|
||||
for _ in range(30):
|
||||
if any(isinstance(x, Exception) for x in session.handled_incoming):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_http_status_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
session._request_id = 1
|
||||
resp_queue = queue.Queue()
|
||||
session._response_streams[0] = resp_queue
|
||||
|
||||
response = Response(status_code=401, request=Request("GET", "http://test"))
|
||||
# Using 401 specifically as _receive_loop preserves it
|
||||
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
|
||||
got = resp_queue.get(timeout=2)
|
||||
assert isinstance(got, HTTPStatusError)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_http_status_error_non_401(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
session._request_id = 1
|
||||
resp_queue = queue.Queue()
|
||||
session._response_streams[0] = resp_queue
|
||||
|
||||
response = Response(status_code=500, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Server Error", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
|
||||
got = resp_queue.get(timeout=2)
|
||||
assert isinstance(got, JSONRPCError)
|
||||
assert got.error.code == 500
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_check_receiver_status_fail(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def raise_err():
|
||||
raise RuntimeError("Receiver failed")
|
||||
|
||||
future = executor.submit(raise_err)
|
||||
session._receiver_future = future
|
||||
|
||||
try:
|
||||
future.result()
|
||||
except:
|
||||
pass
|
||||
|
||||
with pytest.raises(RuntimeError, match="Receiver failed"):
|
||||
session.check_receiver_status()
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_unknown_request_id(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True})
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(resp)))
|
||||
|
||||
for _ in range(30):
|
||||
if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert any("Server Error" in str(x) for x in session.handled_incoming)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_http_error_unknown_id(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
response = Response(status_code=401, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
|
||||
for _ in range(30):
|
||||
if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert any("unknown request ID" in str(x) for x in session.handled_incoming)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_validation_error_notification(streams):
|
||||
from core.mcp.session.base_session import logger
|
||||
|
||||
with patch.object(logger, "warning") as mock_warning:
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification])
|
||||
|
||||
with session:
|
||||
notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
|
||||
time.sleep(1.0)
|
||||
|
||||
assert mock_warning.called
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_none_response(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_none():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
session._response_streams[req_id].put(None)
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_none, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPConnectionError) as exc:
|
||||
session.send_request(request, MockResult)
|
||||
assert exc.value.args[0].message == "No response received"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_session_exit_timeout(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
mock_future = MagicMock(spec=Future)
|
||||
mock_future.result.side_effect = TimeoutError()
|
||||
mock_future.done.return_value = False
|
||||
|
||||
session._receiver_future = mock_future
|
||||
session._executor = MagicMock(spec=ThreadPoolExecutor)
|
||||
|
||||
session.__exit__(None, None, None)
|
||||
|
||||
mock_future.cancel.assert_called_once()
|
||||
session._executor.shutdown.assert_called_once_with(wait=False)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_fatal_exception(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")):
|
||||
with patch("core.mcp.session.base_session.logger") as mock_logger:
|
||||
with pytest.raises(RuntimeError, match="Fatal loop error"):
|
||||
with session:
|
||||
pass
|
||||
mock_logger.exception.assert_called_with("Error in message processing loop")
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_receive_loop_empty_coverage(streams):
|
||||
with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
with session:
|
||||
time.sleep(0.3)
|
||||
|
||||
|
||||
@pytest.mark.timeout(2)
|
||||
def test_base_methods_noop(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
session._received_request(MagicMock())
|
||||
session._received_notification(MagicMock())
|
||||
session.send_progress_notification("token", 0.5)
|
||||
session._handle_incoming(MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_session_timeout_retry_6(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(
|
||||
read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1)
|
||||
)
|
||||
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]):
|
||||
with pytest.raises(RuntimeError, match="timeout_broken"):
|
||||
session.send_request(request, MockResult)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user