Compare commits

..

2 Commits

Author SHA1 Message Date
autofix-ci[bot]
cd7c0c1802 [autofix.ci] apply automated fixes 2026-03-11 16:27:44 +00:00
dependabot[bot]
e9271bf6d1 chore(deps): bump the opentelemetry group across 1 directory with 16 updates
Bumps the opentelemetry group with 16 updates in the /api directory:

| Package | From | To |
| --- | --- | --- |
| [opentelemetry-api](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-distro](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-exporter-otlp](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-exporter-otlp-proto-common](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-exporter-otlp-proto-grpc](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-exporter-otlp-proto-http](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-instrumentation](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-instrumentation-celery](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-instrumentation-flask](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-instrumentation-httpx](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-instrumentation-redis](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-instrumentation-sqlalchemy](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |
| [opentelemetry-proto](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-sdk](https://github.com/open-telemetry/opentelemetry-python) | `1.28.0` | `1.40.0` |
| [opentelemetry-semantic-conventions](https://github.com/open-telemetry/opentelemetry-python) | `0.49b0` | `0.61b0` |
| [opentelemetry-util-http](https://github.com/open-telemetry/opentelemetry-python-contrib) | `0.49b0` | `0.61b0` |



Updates `opentelemetry-api` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-distro` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-exporter-otlp` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-exporter-otlp-proto-common` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-exporter-otlp-proto-grpc` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-exporter-otlp-proto-http` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-instrumentation` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-instrumentation-celery` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-instrumentation-flask` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-instrumentation-httpx` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-instrumentation-redis` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-instrumentation-sqlalchemy` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)

Updates `opentelemetry-proto` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-sdk` from 1.28.0 to 1.40.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/compare/v1.28.0...v1.40.0)

Updates `opentelemetry-semantic-conventions` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python/commits)

Updates `opentelemetry-util-http` from 0.49b0 to 0.61b0
- [Release notes](https://github.com/open-telemetry/opentelemetry-python-contrib/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-python-contrib/commits)
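Two version tracks move in lockstep here: the core `opentelemetry-python` packages go 1.28.0 → 1.40.0 while the contrib packages go 0.49b0 → 0.61b0 (twelve minor releases each). The `update-type` classification in the metadata below can be reproduced for these plain dotted versions with a small sketch — `parse_version` and `update_type` are hypothetical helpers, not Dependabot code:

```python
# Hypothetical helpers (not Dependabot code): classify a bump the way the
# update-type metadata does, for versions like "1.28.0" or "0.49b0".
def parse_version(v: str) -> tuple[int, ...]:
    """Keep only the leading digits of each dot-separated part."""
    parts = []
    for part in v.split("."):
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break  # drop pre-release tails like "b0" in "49b0"
            digits += ch
        parts.append(int(digits or "0"))
    return tuple(parts)

def update_type(old: str, new: str) -> str:
    o, n = parse_version(old), parse_version(new)
    if n[0] != o[0]:
        return "version-update:semver-major"
    if n[1:2] != o[1:2]:
        return "version-update:semver-minor"
    return "version-update:semver-patch"
```

Both `update_type("1.28.0", "1.40.0")` and `update_type("0.49b0", "0.61b0")` classify as semver-minor, matching the metadata block.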

---
updated-dependencies:
- dependency-name: opentelemetry-api
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-distro
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-exporter-otlp
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-exporter-otlp-proto-common
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-exporter-otlp-proto-grpc
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-exporter-otlp-proto-http
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-instrumentation
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-instrumentation-celery
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-instrumentation-flask
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-instrumentation-httpx
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-instrumentation-redis
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-instrumentation-sqlalchemy
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-proto
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-sdk
  dependency-version: 1.40.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: opentelemetry
- dependency-name: opentelemetry-semantic-conventions
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
- dependency-name: opentelemetry-util-http
  dependency-version: 0.61b0
  dependency-type: direct:production
  dependency-group: opentelemetry
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-11 16:25:36 +00:00
1407 changed files with 10969 additions and 82273 deletions

View File

@@ -17,8 +17,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        shardIndex: [1, 2, 3, 4, 5, 6]
-        shardTotal: [6]
+        shardIndex: [1, 2, 3, 4]
+        shardTotal: [4]
     defaults:
       run:
         shell: bash
@@ -72,7 +72,7 @@ jobs:
           merge-multiple: true
       - name: Merge reports
-        run: pnpm vitest --merge-reports --reporter=json --reporter=agent --coverage
+        run: pnpm vitest --merge-reports --coverage --silent=passed-only
       - name: Coverage Summary
         if: always()
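The first hunk above shrinks the CI shard matrix from six shards to four; each `shardIndex`/`shardTotal` pair makes one job run a slice of the test suite. A minimal round-robin sketch of that idea (illustrative only — vitest's own `--shard` distribution may differ):

```python
# Illustrative round-robin sharding; vitest's --shard may split differently.
def shard(files: list[str], index: int, total: int) -> list[str]:
    """Return the 1-based shard `index` of `total` round-robin shards."""
    if not 1 <= index <= total:
        raise ValueError("index must be in 1..total")
    return files[index - 1 :: total]
```

With `shardTotal: [4]`, the four slices together cover every file exactly once, so merging the four reports afterwards reconstructs the full run.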

View File

@@ -188,6 +188,7 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
 # Weaviate configuration
 WEAVIATE_ENDPOINT=http://localhost:8080
 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
+WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100
 WEAVIATE_TOKENIZATION=word

View File

@@ -17,6 +17,11 @@ class WeaviateConfig(BaseSettings):
         default=None,
     )
+    WEAVIATE_GRPC_ENABLED: bool = Field(
+        description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
+        default=True,
+    )
     WEAVIATE_GRPC_ENDPOINT: str | None = Field(
         description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
         default=None,
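The hunks above wire up `WEAVIATE_GRPC_ENABLED`: the `.env` examples set it to `false` while the `Field` default is `True`. A settings layer such as pydantic-settings parses the env string into a bool roughly like this sketch (the accepted spellings here are an assumption, not Dify's exact rules):

```python
import os

# Sketch of env-var boolean parsing for values like WEAVIATE_GRPC_ENABLED=false.
# The accepted spellings are an assumption, not the library's exact rules.
_TRUTHY = {"1", "true", "yes", "on"}
_FALSY = {"0", "false", "no", "off"}

def env_bool(name: str, default: bool) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default  # unset: fall back to the Field-style default
    value = raw.strip().lower()
    if value in _TRUTHY:
        return True
    if value in _FALSY:
        return False
    raise ValueError(f"{name}={raw!r} is not a recognized boolean")
```

So an unset variable yields the in-code default (`True`), while the `.env` example's `false` overrides it.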

View File

@@ -114,7 +114,6 @@ def get_user_tenant(view_func: Callable[P, R]):
 def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
     def decorator(view_func: Callable[P, R]):
         @wraps(view_func)
         def decorated_view(*args: P.args, **kwargs: P.kwargs):
             try:
                 data = request.get_json()
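The `plugin_data` decorator shown here validates the JSON body into `payload_type` before the view runs. A stripped-down sketch of the same shape — a plain callable stands in for the Pydantic model, and the body is passed in as `data` instead of being read from Flask's `request`:

```python
from functools import wraps

# Stripped-down plugin_data-style decorator. payload_type is any callable
# that validates keyword arguments (a stand-in for a Pydantic model);
# `data` replaces Flask's request.get_json() so the sketch is self-contained.
def plugin_data(*, payload_type):
    def decorator(view_func):
        @wraps(view_func)  # preserves view_func's __name__/__doc__
        def decorated_view(*args, data=None, **kwargs):
            payload = payload_type(**(data or {}))  # raises on invalid input
            return view_func(*args, payload=payload, **kwargs)
        return decorated_view
    return decorator
```

Applied as `@plugin_data(payload_type=SomeModel)`, the wrapped view receives an already-parsed `payload` argument, and `@wraps` keeps introspection (as the tests further down exploit via `inspect.unwrap`).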

View File

@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
-            elif isinstance(sub_stream_response, ErrorStreamResponse):
+            if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
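This hunk (and the two matching converter hunks below) turns an `elif` into an `if`, which is a behavioral change, not a cleanup: the error check now runs even when the preceding branch matched. A self-contained illustration with stand-in classes (not the real stream responses):

```python
# Stand-in classes, not the real stream responses.
class MessageEndLike:
    pass

class ErrorLike(Exception):
    pass

class Both(MessageEndLike, ErrorLike):  # satisfies both isinstance checks
    pass

def handle_with_elif(resp) -> list[str]:
    taken = []
    if isinstance(resp, MessageEndLike):
        taken.append("metadata")
    elif isinstance(resp, ErrorLike):  # skipped when the first branch ran
        taken.append("error")
    return taken

def handle_with_if(resp) -> list[str]:
    taken = []
    if isinstance(resp, MessageEndLike):
        taken.append("metadata")
    if isinstance(resp, ErrorLike):  # evaluated unconditionally
        taken.append("error")
    return taken
```

For values matching only one check the two forms agree; they diverge only when a value satisfies both.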

View File

@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
-            elif isinstance(sub_stream_response, ErrorStreamResponse):
+            if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:

View File

@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
-            elif isinstance(sub_stream_response, ErrorStreamResponse):
+            if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:

View File

@@ -193,8 +193,7 @@ class LLMGenerator:
             error_step = "generate rule config"
         except Exception as e:
             logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
-            error = str(e)
-            error_step = "generate rule config"
+            rule_config["error"] = str(e)
         rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -280,8 +279,7 @@ class LLMGenerator:
         except Exception as e:
             logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
-            error = str(e)
-            error_step = "handle unexpected exception"
+            rule_config["error"] = str(e)
         rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""

View File

@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
     except ValueError:
         raise
     except Exception:
-        raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
+        raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")

 def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
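Dropping `repr()` here changes the error text for string values: `repr` quotes strings and so distinguishes `"5"` from `5`, while plain interpolation renders them identically. For example:

```python
# repr() quotes strings and exposes their type; plain interpolation does not,
# so the str "5" and the int 5 render identically without it.
plain = [f"The tool parameter value {v} is not valid" for v in ("5", 5)]
quoted = [f"The tool parameter value {v!r} is not valid" for v in ("5", 5)]
```

With `!r` the first message reads `value '5'` and the second `value 5`, which is useful when debugging type-cast failures like the one this exception reports.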

View File

@@ -113,26 +113,17 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         return self.get_credentials_schema_by_type(CredentialType.API_KEY)

-    def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
+    def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
         """
         returns the credentials schema of the provider
-        :param credential_type: the type of the credential, as CredentialType or str; str values
-        are normalized via CredentialType.of and may raise ValueError for invalid values.
-        :return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
-        empty list for CredentialType.UNAUTHORIZED or missing schemas.
-        Reads from self.entity.oauth_schema and self.entity.credentials_schema.
-        Raises ValueError for invalid credential types.
+        :param credential_type: the type of the credential
+        :return: the credentials schema of the provider
         """
-        if isinstance(credential_type, str):
-            credential_type = CredentialType.of(credential_type)
-        if credential_type == CredentialType.OAUTH2:
+        if credential_type == CredentialType.OAUTH2.value:
             return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
         if credential_type == CredentialType.API_KEY:
             return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
-        if credential_type == CredentialType.UNAUTHORIZED:
-            return []
-        raise ValueError(f"Invalid credential type: {credential_type}")

     def get_oauth_client_schema(self) -> list[ProviderConfig]:
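The hunk above touches the str-vs-enum normalization done via `CredentialType.of`. A minimal stand-in showing that pattern (the member values here are assumptions, not the real ones):

```python
from enum import Enum

# Stand-in for the real CredentialType; member values are assumptions.
# `of` normalizes a raw string into an enum member or raises ValueError.
class CredentialType(str, Enum):
    API_KEY = "api-key"
    OAUTH2 = "oauth2"
    UNAUTHORIZED = "unauthorized"

    @classmethod
    def of(cls, value: str) -> "CredentialType":
        try:
            return cls(value)  # look up the member by its value
        except ValueError:
            raise ValueError(f"Invalid credential type: {value}") from None
```

Because the enum mixes in `str`, a member also compares equal to its raw value, which is why checks of the form `credential_type == CredentialType.OAUTH2.value` keep working on plain strings.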

View File

@@ -137,7 +137,6 @@ class ToolFileManager:
             session.add(tool_file)
             session.commit()
             session.refresh(tool_file)

             return tool_file

View File

@@ -276,4 +276,7 @@ class ToolPromptMessage(PromptMessage):
         :return: True if prompt message is empty, False otherwise
         """
-        return super().is_empty() and not self.tool_call_id
+        if not super().is_empty() and not self.tool_call_id:
+            return False
+        return True
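The one-liner and the `if`/`return` form in this hunk are not equivalent: they disagree whenever `tool_call_id` is set. Enumerating the four cases with plain booleans makes that visible (`super_empty` stands in for `super().is_empty()`):

```python
# Plain-boolean models of the two forms; super_empty stands in for
# super().is_empty() and tool_call_id for the real attribute.
def one_liner(super_empty: bool, tool_call_id: str) -> bool:
    return super_empty and not tool_call_id

def if_form(super_empty: bool, tool_call_id: str) -> bool:
    if not super_empty and not tool_call_id:
        return False
    return True
```

The `if` form returns True for any message with a `tool_call_id`, even when the base message content is non-empty, whereas the one-liner requires the base message to be empty too.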

View File

@@ -4,8 +4,7 @@ class InvokeError(ValueError):
     description: str | None = None

     def __init__(self, description: str | None = None):
-        if description is not None:
-            self.description = description
+        self.description = description

     def __str__(self):
         return self.description or self.__class__.__name__

View File

@@ -282,8 +282,7 @@ class ModelProviderFactory:
                 all_model_type_models.append(model_schema)

             simple_provider_schema = provider_schema.to_simple_provider()
             if model_type:
-                simple_provider_schema.models = all_model_type_models
+                simple_provider_schema.models.extend(all_model_type_models)

             providers.append(simple_provider_schema)
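The switch between `=` and `.extend()` in this hunk changes aliasing: assignment rebinds the attribute to the caller's list (so later mutations are shared, and any models already on the schema are discarded), while `.extend()` copies the items into the schema's existing list. Toy objects, not the real provider schema:

```python
# Toy objects showing the aliasing difference between `models = incoming`
# and `models.extend(incoming)`.
class Schema:
    def __init__(self):
        self.models = []

incoming = ["model-a", "model-b"]

assigned = Schema()
assigned.models = incoming           # rebinds: same list object as incoming
extended = Schema()
extended.models.extend(incoming)     # copies items into Schema's own list

incoming.append("model-c")           # mutate the source afterwards
```

After the mutation, `assigned.models` sees `"model-c"` (shared object) while `extended.models` does not.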

View File

@@ -7,11 +7,11 @@ dependencies = [
     "aliyun-log-python-sdk~=0.9.37",
     "arize-phoenix-otel~=0.15.0",
     "azure-identity==1.25.2",
-    "beautifulsoup4==4.12.2",
+    "beautifulsoup4==4.14.3",
     "boto3==1.42.65",
     "bs4~=0.0.1",
     "cachetools~=5.3.0",
-    "celery~=5.5.2",
+    "celery~=5.6.2",
     "charset-normalizer>=3.4.4",
     "flask~=3.1.2",
     "flask-compress>=1.17,<1.24",
@@ -35,30 +35,30 @@ dependencies = [
     "jsonschema>=4.25.1",
     "langfuse~=2.51.3",
     "langsmith~=0.7.16",
-    "markdown~=3.8.1",
+    "markdown~=3.10.2",
     "mlflow-skinny>=3.0.0",
     "numpy~=1.26.4",
     "openpyxl~=3.1.5",
     "opik~=1.10.37",
     "litellm==1.82.1", # Pinned to avoid madoka dependency issue
-    "opentelemetry-api==1.28.0",
-    "opentelemetry-distro==0.49b0",
-    "opentelemetry-exporter-otlp==1.28.0",
-    "opentelemetry-exporter-otlp-proto-common==1.28.0",
-    "opentelemetry-exporter-otlp-proto-grpc==1.28.0",
-    "opentelemetry-exporter-otlp-proto-http==1.28.0",
-    "opentelemetry-instrumentation==0.49b0",
-    "opentelemetry-instrumentation-celery==0.49b0",
-    "opentelemetry-instrumentation-flask==0.49b0",
-    "opentelemetry-instrumentation-httpx==0.49b0",
-    "opentelemetry-instrumentation-redis==0.49b0",
-    "opentelemetry-instrumentation-sqlalchemy==0.49b0",
+    "opentelemetry-api==1.40.0",
+    "opentelemetry-distro==0.61b0",
+    "opentelemetry-exporter-otlp==1.40.0",
+    "opentelemetry-exporter-otlp-proto-common==1.40.0",
+    "opentelemetry-exporter-otlp-proto-grpc==1.40.0",
+    "opentelemetry-exporter-otlp-proto-http==1.40.0",
+    "opentelemetry-instrumentation==0.61b0",
+    "opentelemetry-instrumentation-celery==0.61b0",
+    "opentelemetry-instrumentation-flask==0.61b0",
+    "opentelemetry-instrumentation-httpx==0.61b0",
+    "opentelemetry-instrumentation-redis==0.61b0",
+    "opentelemetry-instrumentation-sqlalchemy==0.61b0",
+    "opentelemetry-propagator-b3==1.40.0",
-    "opentelemetry-proto==1.28.0",
-    "opentelemetry-sdk==1.28.0",
-    "opentelemetry-semantic-conventions==0.49b0",
-    "opentelemetry-util-http==0.49b0",
-    "pandas[excel,output-formatting,performance]~=2.2.2",
+    "opentelemetry-proto==1.40.0",
+    "opentelemetry-sdk==1.40.0",
+    "opentelemetry-semantic-conventions==0.61b0",
+    "opentelemetry-util-http==0.61b0",
+    "pandas[excel,output-formatting,performance]~=3.0.1",
     "psycogreen~=1.0.2",
     "psycopg2-binary~=2.9.6",
     "pycryptodome==3.23.0",
@@ -66,22 +66,22 @@ dependencies = [
     "pydantic-extra-types~=2.11.0",
     "pydantic-settings~=2.13.1",
     "pyjwt~=2.11.0",
-    "pypdfium2==5.2.0",
+    "pypdfium2==5.6.0",
     "python-docx~=1.2.0",
-    "python-dotenv==1.0.1",
+    "python-dotenv==1.2.2",
     "pyyaml~=6.0.1",
     "readabilipy~=0.3.0",
     "redis[hiredis]~=7.3.0",
-    "resend~=2.9.0",
-    "sentry-sdk[flask]~=2.28.0",
+    "resend~=2.23.0",
+    "sentry-sdk[flask]~=2.54.0",
     "sqlalchemy~=2.0.29",
-    "starlette==0.49.1",
+    "starlette==0.52.1",
     "tiktoken~=0.12.0",
     "transformers~=5.3.0",
-    "unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
-    "yarl~=1.18.3",
+    "unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
+    "yarl~=1.23.0",
     "webvtt-py~=0.5.1",
-    "sseclient-py~=1.8.0",
+    "sseclient-py~=1.9.0",
     "httpx-sse~=0.4.0",
     "sendgrid~=6.12.3",
     "flask-restx~=1.3.2",
@@ -120,7 +120,7 @@ dev = [
     "pytest-cov~=7.0.0",
     "pytest-env~=1.1.3",
     "pytest-mock~=3.15.1",
-    "testcontainers~=4.13.2",
+    "testcontainers~=4.14.1",
     "types-aiofiles~=25.1.0",
     "types-beautifulsoup4~=4.12.0",
     "types-cachetools~=6.2.0",
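The dependency list above mixes `==` pins with `~=` compatible-release specifiers; `~=X.Y.Z` accepts any `X.Y.*` at or above `Z`, which is why the `yarl` jump from 1.18.3 to 1.23.0 also required changing its specifier. A rough stdlib-only sketch of that rule for plain `X.Y.Z` versions (not a full PEP 440 implementation):

```python
# Rough compatible-release (~=) check for plain X.Y.Z versions.
# Stdlib only; not a full PEP 440 implementation.
def compatible(spec: str, candidate: str) -> bool:
    """True if `candidate` satisfies `~=spec`: same prefix up to the last
    segment of `spec`, and at least as new."""
    base = [int(p) for p in spec.split(".")]
    cand = [int(p) for p in candidate.split(".")]
    return cand[: len(base) - 1] == base[:-1] and cand >= base
```

So `sqlalchemy~=2.0.29` admits 2.0.41 but not 2.1.0, while `yarl~=1.18.3` cannot admit 1.23.0 at all, hence the specifier bump in the hunk.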

View File

@@ -60,6 +60,7 @@ VECTOR_STORE=weaviate
 # Weaviate configuration
 WEAVIATE_ENDPOINT=http://localhost:8080
 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
+WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100
 WEAVIATE_TOKENIZATION=word

View File

@@ -1,313 +0,0 @@
"""
Unit tests for inner_api plugin endpoints
Tests endpoint structure (method existence) for all plugin APIs, plus
handler-level logic tests for representative non-streaming endpoints.
Auth/setup decorators are tested separately in test_auth_wraps.py;
handler tests use inspect.unwrap() to bypass them.
"""
import inspect
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.inner_api.plugin.plugin import (
PluginFetchAppInfoApi,
PluginInvokeAppApi,
PluginInvokeEncryptApi,
PluginInvokeLLMApi,
PluginInvokeLLMWithStructuredOutputApi,
PluginInvokeModerationApi,
PluginInvokeParameterExtractorNodeApi,
PluginInvokeQuestionClassifierNodeApi,
PluginInvokeRerankApi,
PluginInvokeSpeech2TextApi,
PluginInvokeSummaryApi,
PluginInvokeTextEmbeddingApi,
PluginInvokeToolApi,
PluginInvokeTTSApi,
PluginUploadFileRequestApi,
)
def _extract_raw_post(cls):
"""Extract the raw post() method from a plugin endpoint class.
Plugin endpoint methods are wrapped by several decorators (get_user_tenant,
setup_required, plugin_inner_api_only, plugin_data). These decorators
use @wraps where possible. This helper ensures we retrieve the original
post(self, user_model, tenant_model, payload) function by unwrapping
and, if necessary, walking the closure of the innermost wrapper.
"""
bottom = inspect.unwrap(cls.post)
# If unwrap() didn't get us to the raw function (e.g. if a decorator
# missed @wraps), try to extract it from the closure if it looks like
# a plugin_data or similar wrapper that closes over 'view_func'.
if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars:
try:
idx = bottom.__code__.co_freevars.index("view_func")
return bottom.__closure__[idx].cell_contents
except (AttributeError, TypeError, IndexError):
pass
return bottom
class TestPluginInvokeLLMApi:
"""Test PluginInvokeLLMApi endpoint structure"""
@pytest.fixture
def api_instance(self):
return PluginInvokeLLMApi()
def test_has_post_method(self, api_instance):
"""Test that endpoint has post method"""
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeLLMWithStructuredOutputApi:
"""Test PluginInvokeLLMWithStructuredOutputApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeLLMWithStructuredOutputApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeTextEmbeddingApi:
"""Test PluginInvokeTextEmbeddingApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeTextEmbeddingApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeRerankApi:
"""Test PluginInvokeRerankApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeRerankApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeTTSApi:
"""Test PluginInvokeTTSApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeTTSApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeSpeech2TextApi:
"""Test PluginInvokeSpeech2TextApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeSpeech2TextApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeModerationApi:
"""Test PluginInvokeModerationApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeModerationApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeToolApi:
"""Test PluginInvokeToolApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeToolApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeParameterExtractorNodeApi:
"""Test PluginInvokeParameterExtractorNodeApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeParameterExtractorNodeApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeQuestionClassifierNodeApi:
"""Test PluginInvokeQuestionClassifierNodeApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeQuestionClassifierNodeApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeAppApi:
"""Test PluginInvokeAppApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeAppApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeEncryptApi:
"""Test PluginInvokeEncryptApi endpoint structure and handler logic"""
@pytest.fixture
def api_instance(self):
return PluginInvokeEncryptApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask):
"""Test that post() delegates to PluginEncrypter and returns model_dump output"""
# Arrange
mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"}
mock_tenant = MagicMock()
mock_user = MagicMock()
mock_payload = MagicMock()
# Act — extract raw post() bypassing all decorators including plugin_data
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload)
assert result["data"] == {"encrypted": "data"}
assert result.get("error") == ""
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask):
"""Test that post() catches exceptions and returns error response"""
# Arrange
mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed")
mock_tenant = MagicMock()
mock_user = MagicMock()
mock_payload = MagicMock()
# Act
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
assert "encrypt failed" in result["error"]
class TestPluginInvokeSummaryApi:
"""Test PluginInvokeSummaryApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeSummaryApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginUploadFileRequestApi:
"""Test PluginUploadFileRequestApi endpoint structure and handler logic"""
@pytest.fixture
def api_instance(self):
return PluginUploadFileRequestApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin")
def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask):
"""Test that post() generates a signed URL and returns it"""
# Arrange
mock_get_url.return_value = "https://storage.example.com/signed-upload-url"
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_user = MagicMock()
mock_user.id = "user-id"
mock_payload = MagicMock()
mock_payload.filename = "test.pdf"
mock_payload.mimetype = "application/pdf"
# Act
raw_post = _extract_raw_post(PluginUploadFileRequestApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
mock_get_url.assert_called_once_with(
filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id"
)
assert result["data"]["url"] == "https://storage.example.com/signed-upload-url"
class TestPluginFetchAppInfoApi:
"""Test PluginFetchAppInfoApi endpoint structure and handler logic"""
@pytest.fixture
def api_instance(self):
return PluginFetchAppInfoApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation")
def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask):
"""Test that post() fetches app info and returns it"""
# Arrange
mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"}
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_user = MagicMock()
mock_payload = MagicMock()
mock_payload.app_id = "app-123"
# Act
raw_post = _extract_raw_post(PluginFetchAppInfoApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id")
assert result["data"] == {"app_name": "My App", "mode": "chat"}

View File

@@ -1,305 +0,0 @@
"""
Unit tests for inner_api plugin decorators
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.plugin.wraps import (
TenantUserPayload,
get_user,
get_user_tenant,
plugin_data,
)
class TestTenantUserPayload:
"""Test TenantUserPayload Pydantic model"""
def test_valid_payload(self):
"""Test valid payload passes validation"""
data = {"tenant_id": "tenant123", "user_id": "user456"}
payload = TenantUserPayload.model_validate(data)
assert payload.tenant_id == "tenant123"
assert payload.user_id == "user456"
def test_missing_tenant_id(self):
"""Test missing tenant_id raises ValidationError"""
with pytest.raises(ValidationError):
TenantUserPayload.model_validate({"user_id": "user456"})
def test_missing_user_id(self):
"""Test missing user_id raises ValidationError"""
with pytest.raises(ValidationError):
TenantUserPayload.model_validate({"tenant_id": "tenant123"})
class TestGetUser:
"""Test get_user function"""
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
"""Test returning existing user when found by ID"""
# Arrange
mock_user = MagicMock()
mock_user.id = "user123"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# Act
with app.app_context():
result = get_user("tenant123", "user123")
# Assert
assert result == mock_user
mock_session.query.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_anonymous_user_by_session_id(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test returning existing anonymous user by session_id"""
# Arrange
mock_user = MagicMock()
mock_user.session_id = "anonymous_session"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# Act
with app.app_context():
result = get_user("tenant123", "anonymous_session")
# Assert
assert result == mock_user
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
"""Test creating new user when not found in database"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = None
mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user
# Act
with app.app_context():
result = get_user("tenant123", "user123")
# Assert
assert result == mock_new_user
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_use_default_session_id_when_user_id_none(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test using default session ID when user_id is None"""
# Arrange
mock_user = MagicMock()
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# Act
with app.app_context():
result = get_user("tenant123", None)
# Assert
assert result == mock_user
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_raise_error_on_database_exception(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test raising ValueError when database operation fails"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.side_effect = Exception("Database error")
# Act & Assert
with app.app_context():
with pytest.raises(ValueError, match="user not found"):
get_user("tenant123", "user123")
class TestGetUserTenant:
"""Test get_user_tenant decorator"""
@patch("controllers.inner_api.plugin.wraps.Tenant")
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
"""Test that decorator injects tenant_model and user_model into kwargs"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return {"tenant": tenant_model, "user": user_model}
mock_tenant = MagicMock()
mock_tenant.id = "tenant123"
mock_user = MagicMock()
mock_user.id = "user456"
# Act
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()
# Assert
assert result["tenant"] == mock_tenant
assert result["user"] == mock_user
def test_should_raise_error_when_tenant_id_missing(self, app: Flask):
"""Test that Pydantic ValidationError is raised when tenant_id is missing from payload"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return "success"
# Act & Assert - Pydantic validates payload before manual check
with app.test_request_context(json={"user_id": "user456"}):
with pytest.raises(ValidationError):
protected_view()
def test_should_raise_error_when_tenant_not_found(self, app: Flask):
"""Test that ValueError is raised when tenant is not found"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return "success"
# Act & Assert
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="tenant not found"):
protected_view()
@patch("controllers.inner_api.plugin.wraps.Tenant")
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
"""Test that default session ID is used when user_id is empty string"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return {"tenant": tenant_model, "user": user_model}
mock_tenant = MagicMock()
mock_tenant.id = "tenant123"
mock_user = MagicMock()
# Act - use empty string for user_id to trigger default logic
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()
# Assert
assert result["tenant"] == mock_tenant
assert result["user"] == mock_user
from models.model import DefaultEndUserSessionID
mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID)
class PluginTestPayload:
"""Simple test payload class"""
def __init__(self, data: dict):
self.value = data.get("value")
@classmethod
def model_validate(cls, data: dict):
return cls(data)
class TestPluginData:
"""Test plugin_data decorator"""
def test_should_inject_valid_payload(self, app: Flask):
"""Test that valid payload is injected into kwargs"""
# Arrange
@plugin_data(payload_type=PluginTestPayload)
def protected_view(payload, **kwargs):
return payload
# Act
with app.test_request_context(json={"value": "test_data"}):
result = protected_view()
# Assert
assert result.value == "test_data"
def test_should_raise_error_on_invalid_json(self, app: Flask):
"""Test that ValueError is raised when JSON parsing fails"""
# Arrange
@plugin_data(payload_type=PluginTestPayload)
def protected_view(payload, **kwargs):
return payload
# Act & Assert - Malformed JSON triggers ValueError
with app.test_request_context(data="not valid json", content_type="application/json"):
with pytest.raises(ValueError):
protected_view()
def test_should_raise_error_on_invalid_payload(self, app: Flask):
"""Test that ValueError is raised when payload validation fails"""
# Arrange
class InvalidPayload:
@classmethod
def model_validate(cls, data: dict):
raise Exception("Validation failed")
@plugin_data(payload_type=InvalidPayload)
def protected_view(payload, **kwargs):
return payload
# Act & Assert
with app.test_request_context(json={"data": "test"}):
with pytest.raises(ValueError, match="invalid payload"):
protected_view()
def test_should_work_as_parameterized_decorator(self, app: Flask):
"""Test that decorator works when used with parentheses"""
# Arrange
@plugin_data(payload_type=PluginTestPayload)
def protected_view(payload, **kwargs):
return payload
# Act
with app.test_request_context(json={"value": "parameterized"}):
result = protected_view()
# Assert
assert result.value == "parameterized"
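Taken together, the plugin_data tests describe a decorator contract: parse the request body, validate it with payload_type.model_validate, raise ValueError("invalid payload") on failure, and inject the result as the payload argument. A minimal self-contained sketch of that contract (the _raw_json parameter stands in for Flask request access and is an assumption of this illustration, not the real implementation):

```python
from functools import wraps

def plugin_data(*, payload_type):
    """Validate the request body into payload_type and inject it as `payload`.
    Request access is simplified to a `_raw_json` kwarg for this sketch."""
    def decorator(view):
        @wraps(view)
        def wrapper(*args, _raw_json=None, **kwargs):
            try:
                payload = payload_type.model_validate(_raw_json or {})
            except Exception as exc:
                # Any parse/validation failure surfaces as ValueError,
                # mirroring the behavior the tests above assert.
                raise ValueError(f"invalid payload: {exc}") from exc
            return view(*args, payload=payload, **kwargs)
        return wrapper
    return decorator

class EchoPayload:
    def __init__(self, value):
        self.value = value

    @classmethod
    def model_validate(cls, data):
        return cls(data["value"])

@plugin_data(payload_type=EchoPayload)
def view(payload):
    return payload.value
```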

@@ -1,309 +0,0 @@
"""
Unit tests for inner_api auth decorators
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
from configs import dify_config
from controllers.inner_api.wraps import (
billing_inner_api_only,
enterprise_inner_api_only,
enterprise_inner_api_user_auth,
plugin_inner_api_only,
)
class TestBillingInnerApiOnly:
"""Test billing_inner_api_only decorator"""
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
"""Test that valid API key allows access when INNER_API is enabled"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
result = protected_view()
# Assert
assert result == "success"
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
"""Test that 404 is returned when INNER_API is disabled"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context():
with patch.object(dify_config, "INNER_API", False):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404
def test_should_return_401_when_api_key_missing(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
def test_should_return_401_when_api_key_invalid(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
class TestEnterpriseInnerApiOnly:
"""Test enterprise_inner_api_only decorator"""
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
"""Test that valid API key allows access when INNER_API is enabled"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
result = protected_view()
# Assert
assert result == "success"
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
"""Test that 404 is returned when INNER_API is disabled"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context():
with patch.object(dify_config, "INNER_API", False):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404
def test_should_return_401_when_api_key_missing(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
def test_should_return_401_when_api_key_invalid(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
class TestEnterpriseInnerApiUserAuth:
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
def test_should_pass_through_when_inner_api_disabled(self, app: Flask):
"""Test that request passes through when INNER_API is disabled"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act
with app.test_request_context():
with patch.object(dify_config, "INNER_API", False):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_pass_through_when_authorization_header_missing(self, app: Flask):
"""Test that request passes through when Authorization header is missing"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act
with app.test_request_context(headers={}):
with patch.object(dify_config, "INNER_API", True):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_pass_through_when_authorization_format_invalid(self, app: Flask):
"""Test that request passes through when Authorization format is invalid (no colon)"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act
with app.test_request_context(headers={"Authorization": "invalid_format"}):
with patch.object(dify_config, "INNER_API", True):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask):
"""Test that request passes through when HMAC signature is invalid"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act - use wrong signature
with app.test_request_context(
headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"}
):
with patch.object(dify_config, "INNER_API", True):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_inject_user_when_hmac_signature_valid(self, app: Flask):
"""Test that user is injected when HMAC signature is valid"""
# Arrange
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user")
# Calculate valid HMAC signature
user_id = "user123"
inner_api_key = "valid_key"
data_to_sign = f"DIFY {user_id}"
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
valid_signature = b64encode(signature.digest()).decode("utf-8")
# Create mock user
mock_user = MagicMock()
mock_user.id = user_id
# Act
with app.test_request_context(
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
):
with patch.object(dify_config, "INNER_API", True):
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = mock_user
result = protected_view()
# Assert
assert result == mock_user
class TestPluginInnerApiOnly:
"""Test plugin_inner_api_only decorator"""
def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask):
"""Test that valid API key allows access when PLUGIN_DAEMON_KEY is set"""
# Arrange
@plugin_inner_api_only
def protected_view():
return "success"
# Act
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}):
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
result = protected_view()
# Assert
assert result == "success"
def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask):
"""Test that 404 is returned when PLUGIN_DAEMON_KEY is not set"""
# Arrange
@plugin_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context():
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404
def test_should_return_404_when_api_key_invalid(self, app: Flask):
"""Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)"""
# Arrange
@plugin_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404

@@ -1,206 +0,0 @@
"""
Unit tests for inner_api mail module
"""
from unittest.mock import patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.mail import (
BaseMail,
BillingMail,
EnterpriseMail,
InnerMailPayload,
)
class TestInnerMailPayload:
"""Test InnerMailPayload Pydantic model"""
def test_valid_payload_with_all_fields(self):
"""Test valid payload with all fields passes validation"""
data = {
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
"substitutions": {"key": "value"},
}
payload = InnerMailPayload.model_validate(data)
assert payload.to == ["test@example.com"]
assert payload.subject == "Test Subject"
assert payload.body == "Test Body"
assert payload.substitutions == {"key": "value"}
def test_valid_payload_without_substitutions(self):
"""Test valid payload without optional substitutions"""
data = {
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
payload = InnerMailPayload.model_validate(data)
assert payload.to == ["test@example.com"]
assert payload.subject == "Test Subject"
assert payload.body == "Test Body"
assert payload.substitutions is None
def test_empty_to_list_fails_validation(self):
"""Test that empty 'to' list fails validation due to min_length=1"""
data = {
"to": [],
"subject": "Test Subject",
"body": "Test Body",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
def test_multiple_recipients_allowed(self):
"""Test that multiple recipients are allowed"""
data = {
"to": ["user1@example.com", "user2@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
payload = InnerMailPayload.model_validate(data)
assert len(payload.to) == 2
assert "user1@example.com" in payload.to
assert "user2@example.com" in payload.to
def test_missing_to_field_fails_validation(self):
"""Test that missing 'to' field fails validation"""
data = {
"subject": "Test Subject",
"body": "Test Body",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
def test_missing_subject_fails_validation(self):
"""Test that missing 'subject' field fails validation"""
data = {
"to": ["test@example.com"],
"body": "Test Body",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
def test_missing_body_fails_validation(self):
"""Test that missing 'body' field fails validation"""
data = {
"to": ["test@example.com"],
"subject": "Test Subject",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
class TestBaseMail:
"""Test BaseMail API endpoint"""
@pytest.fixture
def api_instance(self):
"""Create BaseMail API instance"""
return BaseMail()
@patch("controllers.inner_api.mail.send_inner_email_task")
def test_post_sends_email_task(self, mock_task, api_instance, app: Flask):
"""Test that POST sends inner email task"""
# Arrange
mock_task.delay.return_value = None
# Act
with app.test_request_context(
json={
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
):
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
mock_ns.payload = {
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
result = api_instance.post()
# Assert
assert result == ({"message": "success"}, 200)
mock_task.delay.assert_called_once_with(
to=["test@example.com"],
subject="Test Subject",
body="Test Body",
substitutions=None,
)
@patch("controllers.inner_api.mail.send_inner_email_task")
def test_post_with_substitutions(self, mock_task, api_instance, app: Flask):
"""Test that POST sends email with substitutions"""
# Arrange
mock_task.delay.return_value = None
# Act
with app.test_request_context():
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
mock_ns.payload = {
"to": ["test@example.com"],
"subject": "Hello {{name}}",
"body": "Welcome {{name}}!",
"substitutions": {"name": "John"},
}
result = api_instance.post()
# Assert
assert result == ({"message": "success"}, 200)
mock_task.delay.assert_called_once_with(
to=["test@example.com"],
subject="Hello {{name}}",
body="Welcome {{name}}!",
substitutions={"name": "John"},
)
class TestEnterpriseMail:
"""Test EnterpriseMail API endpoint"""
@pytest.fixture
def api_instance(self):
"""Create EnterpriseMail API instance"""
return EnterpriseMail()
def test_has_enterprise_inner_api_only_decorator(self, api_instance):
"""Test that EnterpriseMail has enterprise_inner_api_only decorator"""
# Check method_decorators
from controllers.inner_api.wraps import enterprise_inner_api_only
assert enterprise_inner_api_only in api_instance.method_decorators
def test_has_setup_required_decorator(self, api_instance):
"""Test that EnterpriseMail has setup_required decorator"""
# Check by decorator name instead of object reference
decorator_names = [d.__name__ for d in api_instance.method_decorators]
assert "setup_required" in decorator_names
class TestBillingMail:
"""Test BillingMail API endpoint"""
@pytest.fixture
def api_instance(self):
"""Create BillingMail API instance"""
return BillingMail()
def test_has_billing_inner_api_only_decorator(self, api_instance):
"""Test that BillingMail has billing_inner_api_only decorator"""
# Check method_decorators
from controllers.inner_api.wraps import billing_inner_api_only
assert billing_inner_api_only in api_instance.method_decorators
def test_has_setup_required_decorator(self, api_instance):
"""Test that BillingMail has setup_required decorator"""
# Check by decorator name instead of object reference
decorator_names = [d.__name__ for d in api_instance.method_decorators]
assert "setup_required" in decorator_names

@@ -1,184 +0,0 @@
"""
Unit tests for inner_api workspace module
Tests Pydantic model validation and endpoint handler logic.
Auth/setup decorators are tested separately in test_auth_wraps.py;
handler tests use inspect.unwrap() to bypass them and focus on business logic.
"""
import inspect
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.workspace.workspace import (
EnterpriseWorkspace,
EnterpriseWorkspaceNoOwnerEmail,
WorkspaceCreatePayload,
WorkspaceOwnerlessPayload,
)
class TestWorkspaceCreatePayload:
"""Test WorkspaceCreatePayload Pydantic model validation"""
def test_valid_payload(self):
"""Test valid payload with all fields passes validation"""
data = {
"name": "My Workspace",
"owner_email": "owner@example.com",
}
payload = WorkspaceCreatePayload.model_validate(data)
assert payload.name == "My Workspace"
assert payload.owner_email == "owner@example.com"
def test_missing_name_fails_validation(self):
"""Test that missing name fails validation"""
data = {"owner_email": "owner@example.com"}
with pytest.raises(ValidationError) as exc_info:
WorkspaceCreatePayload.model_validate(data)
assert "name" in str(exc_info.value)
def test_missing_owner_email_fails_validation(self):
"""Test that missing owner_email fails validation"""
data = {"name": "My Workspace"}
with pytest.raises(ValidationError) as exc_info:
WorkspaceCreatePayload.model_validate(data)
assert "owner_email" in str(exc_info.value)
class TestWorkspaceOwnerlessPayload:
"""Test WorkspaceOwnerlessPayload Pydantic model validation"""
def test_valid_payload(self):
"""Test valid payload with name passes validation"""
data = {"name": "My Workspace"}
payload = WorkspaceOwnerlessPayload.model_validate(data)
assert payload.name == "My Workspace"
def test_missing_name_fails_validation(self):
"""Test that missing name fails validation"""
data = {}
with pytest.raises(ValidationError) as exc_info:
WorkspaceOwnerlessPayload.model_validate(data)
assert "name" in str(exc_info.value)
class TestEnterpriseWorkspace:
"""Test EnterpriseWorkspace API endpoint handler logic.
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
and exercise the core business logic directly.
"""
@pytest.fixture
def api_instance(self):
return EnterpriseWorkspace()
def test_has_post_method(self, api_instance):
"""Test that EnterpriseWorkspace has post method"""
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
@patch("controllers.inner_api.workspace.workspace.TenantService")
@patch("controllers.inner_api.workspace.workspace.db")
def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask):
"""Test that post() creates a workspace and assigns the owner account"""
# Arrange
mock_account = MagicMock()
mock_account.email = "owner@example.com"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
now = datetime(2025, 1, 1, 12, 0, 0)
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_tenant.name = "My Workspace"
mock_tenant.plan = "sandbox"
mock_tenant.status = "normal"
mock_tenant.created_at = now
mock_tenant.updated_at = now
mock_tenant_svc.create_tenant.return_value = mock_tenant
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
unwrapped_post = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"}
result = unwrapped_post(api_instance)
# Assert
assert result["message"] == "enterprise workspace created."
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["name"] == "My Workspace"
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
mock_event.send.assert_called_once_with(mock_tenant)
@patch("controllers.inner_api.workspace.workspace.db")
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
"""Test that post() returns 404 when the owner account does not exist"""
# Arrange
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Act
unwrapped_post = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"}
result = unwrapped_post(api_instance)
# Assert
assert result == ({"message": "owner account not found."}, 404)
class TestEnterpriseWorkspaceNoOwnerEmail:
"""Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic.
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
and exercise the core business logic directly.
"""
@pytest.fixture
def api_instance(self):
return EnterpriseWorkspaceNoOwnerEmail()
def test_has_post_method(self, api_instance):
"""Test that endpoint has post method"""
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
@patch("controllers.inner_api.workspace.workspace.TenantService")
def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask):
"""Test that post() creates a workspace without an owner and returns expected fields"""
# Arrange
now = datetime(2025, 1, 1, 12, 0, 0)
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_tenant.name = "My Workspace"
mock_tenant.encrypt_public_key = "pub-key"
mock_tenant.plan = "sandbox"
mock_tenant.status = "normal"
mock_tenant.custom_config = None
mock_tenant.created_at = now
mock_tenant.updated_at = now
mock_tenant_svc.create_tenant.return_value = mock_tenant
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
unwrapped_post = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
mock_ns.payload = {"name": "My Workspace"}
result = unwrapped_post(api_instance)
# Assert
assert result["message"] == "enterprise workspace created."
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["encrypt_public_key"] == "pub-key"
assert result["tenant"]["custom_config"] == {}
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_event.send.assert_called_once_with(mock_tenant)
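These handler tests rely on inspect.unwrap(), which follows the `__wrapped__` chain that functools.wraps records on each decorator layer and returns the innermost function. A small standalone illustration:

```python
import functools
import inspect

def noisy(fn):
    # functools.wraps sets wrapper.__wrapped__ = fn, which is what
    # inspect.unwrap follows to bypass the decorator.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@noisy
def add(a, b):
    return a + b

unwrapped = inspect.unwrap(add)
```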

@@ -1,80 +0,0 @@
import pytest
class DummyTool:
def __init__(self, name):
self.name = name
class DummyPromptEntity:
def __init__(self, first_prompt):
self.first_prompt = first_prompt
class DummyAgentConfig:
def __init__(self, prompt_entity=None):
self.prompt = prompt_entity
class DummyAppConfig:
def __init__(self, agent=None):
self.agent = agent
class DummyScratchpadUnit:
def __init__(
self,
final=False,
thought=None,
action_str=None,
observation=None,
agent_response=None,
):
self._final = final
self.thought = thought
self.action_str = action_str
self.observation = observation
self.agent_response = agent_response
def is_final(self):
return self._final
@pytest.fixture
def dummy_tool_factory():
def _factory(name):
return DummyTool(name)
return _factory
@pytest.fixture
def dummy_prompt_entity_factory():
def _factory(first_prompt):
return DummyPromptEntity(first_prompt)
return _factory
@pytest.fixture
def dummy_agent_config_factory():
def _factory(prompt_entity=None):
return DummyAgentConfig(prompt_entity)
return _factory
@pytest.fixture
def dummy_app_config_factory():
def _factory(agent=None):
return DummyAppConfig(agent)
return _factory
@pytest.fixture
def dummy_scratchpad_unit_factory():
def _factory(**kwargs):
return DummyScratchpadUnit(**kwargs)
return _factory

@@ -1,255 +1,70 @@
"""Unit tests for CotAgentOutputParser.
Verifies expected parsing behavior for streaming content and JSON payloads,
including edge cases such as empty/non-string content and malformed JSON.
Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real
model output structures. Implementation under test:
core.agent.output_parser.cot_output_parser.CotAgentOutputParser.
"""
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from collections.abc import Generator
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
@pytest.fixture
def mock_action_class(mocker):
mock_action = MagicMock()
mocker.patch(
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
mock_action,
)
return mock_action
@pytest.fixture
def usage_dict():
return {}
@pytest.fixture
def make_chunk():
def _make_chunk(content=None, usage=None):
delta = SimpleNamespace(
message=SimpleNamespace(content=content),
usage=usage,
)
return SimpleNamespace(delta=delta)
return _make_chunk
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
for i in range(len(text)):
yield LLMResultChunk(
model="model",
prompt_messages=[],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
)
# ============================================================
# Test Suite
# ============================================================
def test_cot_output_parser():
test_cases = [
{
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# code block with json
{
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
'}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# code block with JSON
{
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
'}\n```"}```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# list
{
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# no code block
{
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
},
# no code block and json
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
    ]
    parser = CotAgentOutputParser()
    usage_dict = {}
    for test_case in test_cases:
        # Mock llm_response as a generator over the input text.
        llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
        results = parser.handle_react_stream_output(llm_response, usage_dict)
        output = ""
        for result in results:
            if isinstance(result, str):
                output += result
            elif isinstance(result, AgentScratchpadUnit.Action):
                if test_case["action"]:
                    assert result.to_dict() == test_case["action"]
                output += json.dumps(result.to_dict())
        if test_case["output"]:
            assert output == test_case["output"]


class TestCotAgentOutputParser:
    """Validate CotAgentOutputParser streaming and JSON parsing behavior.

    Lifecycle: no explicit setup/teardown; relies on pytest fixtures for
    lightweight chunk/action doubles. Invariants: non-string/empty content
    yields no output, usage is recorded when provided, and valid action JSON
    results in Action instantiation. Usage: invoke via pytest (e.g.,
    `pytest -k TestCotAgentOutputParser`).
    """
# --------------------------------------------------------
# Basic streaming & usage
# --------------------------------------------------------
def test_stream_plain_text(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("hello world")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert "".join(result) == "hello world"
def test_stream_empty_string(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result == []
def test_stream_none_content(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk(None)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result == []
@pytest.mark.parametrize("content", [123, 12.5, [], {}, object()])
def test_non_string_content(self, make_chunk, usage_dict, content) -> None:
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result == []
def test_usage_update(self, make_chunk, usage_dict) -> None:
usage_data = {"tokens": 99}
chunks = [make_chunk("abc", usage=usage_data)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert usage_dict["usage"] == usage_data
# --------------------------------------------------------
# JSON parsing (direct + streaming)
# --------------------------------------------------------
def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
content = '{"action": "search", "input": "query"}'
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="search", action_input="query")
def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None:
content = '[{"action": "lookup", "input": "abc"}]'
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None:
content = '{"foo": "bar"}'
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
# Expect the serialized JSON to be yielded as a single element.
assert result == [json.dumps({"foo": "bar"})]
def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None:
content = "{invalid json}"
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert any("invalid json" in str(r) for r in result)
def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None:
chunks = [
make_chunk('{"action": '),
make_chunk('"multi", '),
make_chunk('"input": "step"}'),
]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="multi", action_input="step")
def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk('{"foo": "bar"')]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert all(isinstance(item, str) for item in result)
assert any('{"foo": "bar"' in item for item in result)
# --------------------------------------------------------
# Code block JSON extraction
# --------------------------------------------------------
def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
content = """```json
{"action": "lookup", "input": "abc"}
```"""
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None:
# Multiple JSON objects inside single code fence (invalid combined JSON)
# Parser should safely ignore invalid combined block
content = """```json
{"action": "a1", "input": "x"}
{"action": "a2", "input": "y"}
```"""
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
# No valid parsed action expected due to invalid combined JSON
assert mock_action_class.call_count == 0
assert isinstance(result, list)
def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None:
content = """```json
{invalid}
```"""
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert result
def test_unclosed_code_block(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk('```json {"a":1}')]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert all(isinstance(item, str) for item in result)
assert any('```json {"a":1}' in item for item in result)
# --------------------------------------------------------
# Action / Thought prefix handling
# --------------------------------------------------------
@pytest.mark.parametrize(
"content",
[
" action: something",
" ACTION: something",
" thought: reasoning",
" THOUGHT: reasoning",
],
)
def test_prefix_handling(self, make_chunk, usage_dict, content) -> None:
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
joined = "".join(str(item) for item in result)
expected_word = "something" if "action:" in content.lower() else "reasoning"
assert expected_word in joined
assert "action:" not in joined.lower()
assert "thought:" not in joined.lower()
def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("xaction: test")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert "x" in "".join(map(str, result))
# --------------------------------------------------------
# Mixed streaming scenarios
# --------------------------------------------------------
def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None:
content = 'start {"action": "mix", "input": "1"} end'
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
# JSON action should be parsed
mock_action_class.assert_called_once()
# Ensure surrounding text is streamed (character-level)
joined = "".join(str(r) for r in result if not isinstance(r, MagicMock))
assert "start" in joined
assert "end" in joined
def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None:
content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```'
chunks = [make_chunk(content)]
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert mock_action_class.call_count == 2
def test_backtick_noise(self, make_chunk, usage_dict) -> None:
chunks = [make_chunk("text with ` random ` backticks")]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert "text with" in "".join(result)
# --------------------------------------------------------
# Boundary & edge inputs
# --------------------------------------------------------
@pytest.mark.parametrize(
"content",
[
"```",
"{",
"}",
"```json",
"action:",
"thought:",
" ",
],
)
def test_edge_inputs(self, make_chunk, usage_dict, content) -> None:
chunks = [make_chunk(content)]
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
assert all(isinstance(item, str) for item in result)
joined = "".join(result)
if content == " ":
assert result == [] or joined == content
if content in {"```", "{", "}", "```json"}:
assert content in joined
if content.lower() in {"action:", "thought:"}:
assert "action:" not in joined.lower()
assert "thought:" not in joined.lower()
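The test cases above exercise a ReAct-style parser that strips code fences and pulls an action object out of the model text. A minimal sketch of that extraction idea, assuming simple single-object inputs without nested fences (the real CotAgentOutputParser handles streaming chunks and nested ```echarts blocks, which this does not):

```python
import json
import re


def extract_action(text: str):
    # Strip an optional ```/```json/```JSON fence, then parse the payload.
    # A one-element list is unwrapped, mirroring the "list" case above.
    match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    payload = match.group(1) if match else text
    obj = json.loads(payload)
    if isinstance(obj, list):
        obj = obj[0]
    return obj


print(extract_action('```JSON\n{"action": "Final Answer", "action_input": "x"}\n```'))
# {'action': 'Final Answer', 'action_input': 'x'}
```

The fence-insensitive match (`json`, `JSON`, or no tag) is exactly the variation the three fenced test cases cover.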


@@ -1,174 +0,0 @@
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
from core.agent.strategy.base import BaseAgentStrategy
class DummyStrategy(BaseAgentStrategy):
"""
Concrete implementation for testing BaseAgentStrategy
"""
def __init__(self, return_values=None, raise_exception=None):
self.return_values = return_values or []
self.raise_exception = raise_exception
self.received_args = None
def _invoke(
self,
params,
user_id,
conversation_id=None,
app_id=None,
message_id=None,
credentials=None,
) -> Generator:
self.received_args = (
params,
user_id,
conversation_id,
app_id,
message_id,
credentials,
)
if self.raise_exception:
raise self.raise_exception
yield from self.return_values
class TestBaseAgentStrategyInstantiation:
def test_cannot_instantiate_abstract_class(self) -> None:
with pytest.raises(TypeError):
BaseAgentStrategy()
class TestBaseAgentStrategyInvoke:
@pytest.fixture
def mock_message(self):
return MagicMock(name="AgentInvokeMessage")
@pytest.fixture
def mock_credentials(self):
return MagicMock(name="InvokeCredentials")
@pytest.mark.parametrize(
("params", "user_id", "conversation_id", "app_id", "message_id"),
[
({"key": "value"}, "user1", "conv1", "app1", "msg1"),
({}, "user2", None, None, None),
({"a": 1}, "", "", "", ""),
({"nested": {"x": 1}}, "user3", None, "app3", None),
],
)
def test_invoke_success(
self,
mock_message,
mock_credentials,
params,
user_id,
conversation_id,
app_id,
message_id,
) -> None:
# Arrange
strategy = DummyStrategy(return_values=[mock_message])
# Act
result = list(
strategy.invoke(
params=params,
user_id=user_id,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
credentials=mock_credentials,
)
)
# Assert
assert result == [mock_message]
assert strategy.received_args == (
params,
user_id,
conversation_id,
app_id,
message_id,
mock_credentials,
)
def test_invoke_multiple_yields(self, mock_message) -> None:
# Arrange
messages = [mock_message, MagicMock(), MagicMock()]
strategy = DummyStrategy(return_values=messages)
# Act
result = list(strategy.invoke(params={}, user_id="user"))
# Assert
assert result == messages
def test_invoke_empty_generator(self) -> None:
# Arrange
strategy = DummyStrategy(return_values=[])
# Act
result = list(strategy.invoke(params={}, user_id="user"))
# Assert
assert result == []
def test_invoke_propagates_exception(self) -> None:
# Arrange
strategy = DummyStrategy(raise_exception=ValueError("failure"))
# Act & Assert
with pytest.raises(ValueError, match="failure"):
list(strategy.invoke(params={}, user_id="user"))
@pytest.mark.parametrize(
"invalid_params",
[
None,
"",
123,
[],
],
)
def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None:
"""
Base class does not validate types — ensure pass-through behavior
"""
strategy = DummyStrategy(return_values=[])
result = list(strategy.invoke(params=invalid_params, user_id="user"))
assert result == []
def test_invoke_none_user_id(self) -> None:
strategy = DummyStrategy(return_values=[])
result = list(strategy.invoke(params={}, user_id=None))
assert result == []
class TestBaseAgentStrategyGetParameters:
def test_get_parameters_default_empty_list(self) -> None:
strategy = DummyStrategy()
result = strategy.get_parameters()
assert isinstance(result, list)
assert result == []
def test_get_parameters_returns_new_list_each_time(self) -> None:
strategy = DummyStrategy()
first = strategy.get_parameters()
second = strategy.get_parameters()
assert first == second == []
assert first is not second
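The invoke/_invoke split these tests rely on is a template-method shape: the public generator delegates to an abstract hook, so the ABC itself cannot be instantiated (which is what test_cannot_instantiate_abstract_class checks). A minimal sketch under that assumption (class names here are hypothetical stand-ins, not the real BaseAgentStrategy API):

```python
from abc import ABC, abstractmethod
from collections.abc import Generator


class StrategySketch(ABC):
    # Public entry point: delegates to the abstract hook and re-yields,
    # so exceptions and empty generators pass straight through.
    def invoke(self, params, user_id, **kwargs) -> Generator:
        yield from self._invoke(params, user_id, **kwargs)

    @abstractmethod
    def _invoke(self, params, user_id, **kwargs) -> Generator:
        raise NotImplementedError


class Echo(StrategySketch):
    def _invoke(self, params, user_id, **kwargs) -> Generator:
        yield params


print(list(Echo().invoke({"k": 1}, "user")))  # [{'k': 1}]
```

Because `invoke` only re-yields, propagation tests like test_invoke_propagates_exception follow directly from this shape.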


@@ -1,272 +0,0 @@
# File: tests/unit_tests/core/agent/strategy/test_plugin.py
from unittest.mock import MagicMock
import pytest
from core.agent.strategy.plugin import PluginAgentStrategy
# ============================================================
# Fixtures
# ============================================================
@pytest.fixture
def mock_parameter():
def _factory(name="param", return_value="initialized"):
param = MagicMock()
param.name = name
param.init_frontend_parameter = MagicMock(return_value=return_value)
return param
return _factory
@pytest.fixture
def mock_declaration(mock_parameter):
param1 = mock_parameter("param1", "init1")
param2 = mock_parameter("param2", "init2")
identity = MagicMock()
identity.provider = "provider_x"
identity.name = "strategy_x"
declaration = MagicMock()
declaration.parameters = [param1, param2]
declaration.identity = identity
return declaration
@pytest.fixture
def strategy(mock_declaration):
return PluginAgentStrategy(
tenant_id="tenant_123",
declaration=mock_declaration,
meta_version="v1",
)
# ============================================================
# Initialization Tests
# ============================================================
class TestPluginAgentStrategyInitialization:
def test_init_sets_attributes(self, mock_declaration) -> None:
strategy = PluginAgentStrategy(
tenant_id="tenant_test",
declaration=mock_declaration,
meta_version="meta_v",
)
assert strategy.tenant_id == "tenant_test"
assert strategy.declaration == mock_declaration
assert strategy.meta_version == "meta_v"
def test_init_meta_version_none(self, mock_declaration) -> None:
strategy = PluginAgentStrategy(
tenant_id="tenant_test",
declaration=mock_declaration,
meta_version=None,
)
assert strategy.meta_version is None
# ============================================================
# get_parameters Tests
# ============================================================
class TestGetParameters:
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
result = strategy.get_parameters()
assert result == mock_declaration.parameters
# ============================================================
# initialize_parameters Tests
# ============================================================
class TestInitializeParameters:
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
params = {"param1": "value1"}
result = strategy.initialize_parameters(params.copy())
assert result["param1"] == "init1"
assert result["param2"] == "init2"
mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1")
mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None)
@pytest.mark.parametrize(
"input_params",
[
{},
{"param1": None},
{"param1": ""},
{"param1": 0},
{"param1": []},
{"param1": {}, "param2": "value"},
],
)
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
result = strategy.initialize_parameters(input_params.copy())
for param in strategy.declaration.parameters:
assert param.name in result
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
with pytest.raises(AttributeError):
strategy.initialize_parameters(None)
# ============================================================
# _invoke Tests
# ============================================================
class TestInvoke:
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mock_convert = mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={"converted": True},
)
result = list(
strategy._invoke(
params={"param1": "value"},
user_id="user_1",
conversation_id="conv_1",
app_id="app_1",
message_id="msg_1",
credentials=None,
)
)
assert result == ["msg1", "msg2"]
mock_convert.assert_called_once()
mock_manager.invoke.assert_called_once()
call_kwargs = mock_manager.invoke.call_args.kwargs
assert call_kwargs["tenant_id"] == "tenant_123"
assert call_kwargs["user_id"] == "user_1"
assert call_kwargs["agent_provider"] == "provider_x"
assert call_kwargs["agent_strategy"] == "strategy_x"
assert call_kwargs["agent_params"] == {"converted": True}
assert call_kwargs["conversation_id"] == "conv_1"
assert call_kwargs["app_id"] == "app_1"
assert call_kwargs["message_id"] == "msg_1"
assert call_kwargs["context"] is not None
def test_invoke_with_credentials(self, strategy, mocker) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter([]))
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={},
)
# Patch PluginInvokeContext to bypass pydantic validation
mock_context = MagicMock()
mocker.patch(
"core.agent.strategy.plugin.PluginInvokeContext",
return_value=mock_context,
)
credentials = MagicMock()
result = list(
strategy._invoke(
params={},
user_id="user_1",
credentials=credentials,
)
)
assert result == []
mock_manager.invoke.assert_called_once()
@pytest.mark.parametrize(
("conversation_id", "app_id", "message_id"),
[
(None, None, None),
("conv", None, None),
(None, "app", None),
(None, None, "msg"),
],
)
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
mock_manager = MagicMock()
mock_manager.invoke = MagicMock(return_value=iter([]))
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={},
)
result = list(
strategy._invoke(
params={},
user_id="user_1",
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
)
assert result == []
mock_manager.invoke.assert_called_once()
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=MagicMock(),
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
side_effect=ValueError("conversion failed"),
)
with pytest.raises(ValueError):
list(strategy._invoke(params={}, user_id="user_1"))
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
mock_manager = MagicMock()
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
mocker.patch(
"core.agent.strategy.plugin.PluginAgentClient",
return_value=mock_manager,
)
mocker.patch(
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
return_value={},
)
with pytest.raises(RuntimeError):
list(strategy._invoke(params={}, user_id="user_1"))
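TestInitializeParameters above implies a simple contract: every declared parameter ends up in the result, initialized from the caller's value, or from None when the caller omitted it. A sketch of that loop (`Param` and the `init(...)` tagging are illustrative stand-ins, not the plugin entity API):

```python
class Param:
    def __init__(self, name: str):
        self.name = name

    def init_frontend_parameter(self, value):
        # The real declaration parameter validates/defaults; here we just
        # tag the value so the fallback-to-None path is visible.
        return f"init({value!r})"


def initialize_parameters(params: dict, declared: list) -> dict:
    # Initialize every declared parameter, falling back to None when the
    # caller did not supply a value (mirrors the edge cases tested above).
    for p in declared:
        params[p.name] = p.init_frontend_parameter(params.get(p.name))
    return params


out = initialize_parameters({"param1": "value1"}, [Param("param1"), Param("param2")])
print(out)  # {'param1': "init('value1')", 'param2': 'init(None)'}
```

This also explains test_initialize_parameters_invalid_input_type: passing None raises AttributeError as soon as `.get` is looked up on it.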


@@ -1,802 +0,0 @@
import json
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
import core.agent.base_agent_runner as module
from core.agent.base_agent_runner import BaseAgentRunner
# ==========================================================
# Fixtures
# ==========================================================
@pytest.fixture
def mock_db_session(mocker):
session = mocker.MagicMock()
mocker.patch.object(module.db, "session", session)
return session
@pytest.fixture
def runner(mocker, mock_db_session):
r = BaseAgentRunner.__new__(BaseAgentRunner)
r.tenant_id = "tenant"
r.user_id = "user"
r.agent_thought_count = 0
r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1")
r.app_config = mocker.MagicMock()
r.app_config.app_id = "app1"
r.app_config.agent = None
r.dataset_tools = []
r.application_generate_entity = mocker.MagicMock(invoke_from="test")
r._current_thoughts = []
return r
# ==========================================================
# _repack_app_generate_entity
# ==========================================================
class TestRepack:
def test_sets_empty_if_none(self, runner, mocker):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = None
result = runner._repack_app_generate_entity(entity)
assert result.app_config.prompt_template.simple_prompt_template == ""
def test_keeps_existing(self, runner, mocker):
entity = mocker.MagicMock()
entity.app_config.prompt_template.simple_prompt_template = "abc"
result = runner._repack_app_generate_entity(entity)
assert result.app_config.prompt_template.simple_prompt_template == "abc"
# ==========================================================
# update_prompt_message_tool
# ==========================================================
class TestUpdatePromptTool:
def build_param(self, mocker, **kwargs):
p = mocker.MagicMock()
p.form = kwargs.get("form")
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
p.type = mock_type
p.name = kwargs.get("name", "p1")
p.llm_description = "desc"
p.input_schema = kwargs.get("input_schema")
p.options = kwargs.get("options")
p.required = kwargs.get("required", False)
return p
def test_skip_non_llm(self, runner, mocker):
tool = mocker.MagicMock()
param = self.build_param(mocker, form="NOT_LLM")
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_enum_and_required(self, runner, mocker):
option = mocker.MagicMock(value="opt1")
param = self.build_param(
mocker,
form=module.ToolParameter.ToolParameterForm.LLM,
options=[option],
required=True,
)
tool = mocker.MagicMock()
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert "p1" in result.parameters["required"]
def test_skip_file_type_param(self, runner, mocker):
tool = mocker.MagicMock()
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
param.type = module.ToolParameter.ToolParameterType.FILE
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_duplicate_required_not_duplicated(self, runner, mocker):
tool = mocker.MagicMock()
param = self.build_param(
mocker,
form=module.ToolParameter.ToolParameterForm.LLM,
required=True,
)
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["required"].count("p1") == 1
# ==========================================================
# create_agent_thought
# ==========================================================
class TestCreateAgentThought:
def test_with_files(self, runner, mock_db_session, mocker):
mock_thought = mocker.MagicMock(id=10)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"])
assert result == "10"
assert runner.agent_thought_count == 1
def test_without_files(self, runner, mock_db_session, mocker):
mock_thought = mocker.MagicMock(id=11)
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
result = runner.create_agent_thought("m", "msg", "tool", "input", [])
assert result == "11"
# ==========================================================
# save_agent_thought
# ==========================================================
class TestSaveAgentThought:
def setup_agent(self, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1;tool2"
agent.tool_labels = {}
agent.thought = ""
return agent
def test_not_found(self, runner, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError):
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
def test_full_update(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
mock_label = mocker.MagicMock()
mock_label.to_dict.return_value = {"en_US": "label"}
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label)
usage = mocker.MagicMock(
prompt_tokens=1,
prompt_price_unit=Decimal("0.1"),
prompt_unit_price=Decimal("0.1"),
completion_tokens=2,
completion_price_unit=Decimal("0.2"),
completion_unit_price=Decimal("0.2"),
total_tokens=3,
total_price=Decimal("0.3"),
)
runner.save_agent_thought(
"id",
"tool1;tool2",
{"a": 1},
"thought",
{"b": 2},
{"meta": 1},
"answer",
["f1"],
usage,
)
assert agent.answer == "answer"
assert agent.tokens == 3
assert "tool1" in json.loads(agent.tool_labels_str)
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
agent.tool = "unknown_tool"
mock_db_session.scalar.return_value = agent
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
labels = json.loads(agent.tool_labels_str)
assert "unknown_tool" in labels
def test_json_failure_paths(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
bad_obj = MagicMock()
bad_obj.__str__.return_value = "bad"
runner.save_agent_thought(
"id",
None,
bad_obj,
None,
bad_obj,
bad_obj,
None,
[],
None,
)
assert mock_db_session.commit.called
def test_messages_ids_none(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
assert mock_db_session.commit.called
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
agent = self.setup_agent(mocker)
mock_db_session.scalar.return_value = agent
runner.save_agent_thought(
"id",
None,
{"a": 1},
None,
{"b": 2},
None,
None,
[],
None,
)
assert isinstance(agent.tool_input, str)
assert isinstance(agent.observation, str)
# ==========================================================
# organize_agent_user_prompt
# ==========================================================
class TestOrganizeUserPrompt:
def test_no_files(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = []
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_with_files_no_config(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
file_config.image_config = mocker.MagicMock(detail=None)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[])
msg = mocker.MagicMock(id="1", query="hello")
msg.app_model_config.to_dict.return_value = {}
result = runner.organize_agent_user_prompt(msg)
assert result.content == "hello"
# ==========================================================
# organize_agent_history
# ==========================================================
class TestOrganizeHistory:
def test_empty(self, runner, mock_db_session, mocker):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
result = runner.organize_agent_history([])
assert result == []
def test_with_answer_only(self, runner, mock_db_session, mocker):
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
def test_skip_current_message(self, runner, mock_db_session, mocker):
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert result == []
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1",
tool_input="invalid",
observation="invalid",
thought="thinking",
)
msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(tool=";", thought="thinking")
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1",
tool_input=json.dumps({"tool1": {"x": 1}}),
observation=json.dumps({"tool1": "obs"}),
thought="thinking",
)
msg = mocker.MagicMock(
id="m100",
agent_thoughts=[thought],
answer=None,
app_model_config=None,
)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
result = runner.organize_agent_history([])
assert isinstance(result, list)
# ==========================================================
# _convert_tool_to_prompt_message_tool (new coverage)
# ==========================================================
class TestConvertToolToPromptMessageTool:
def test_basic_conversion(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
runtime_param = mocker.MagicMock()
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
runtime_param.name = "param1"
runtime_param.llm_description = "desc"
runtime_param.required = True
runtime_param.input_schema = None
runtime_param.options = None
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
runtime_param.type = mock_type
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
def test_full_conversion_multiple_params(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
# LLM param with input_schema override
param1 = mocker.MagicMock()
param1.form = module.ToolParameter.ToolParameterForm.LLM
param1.name = "p1"
param1.llm_description = "desc"
param1.required = True
param1.input_schema = {"type": "integer"}
param1.options = None
param1.type = mocker.MagicMock()
# SYSTEM_FILES param should be skipped
param2 = mocker.MagicMock()
param2.form = module.ToolParameter.ToolParameterForm.LLM
param2.name = "file_param"
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
# ==========================================================
# _init_prompt_tools additional branches
# ==========================================================
class TestInitPromptToolsExtended:
def test_agent_tool_branch(self, runner, mocker):
agent_tool = mocker.MagicMock(tool_name="agent_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
tools, prompts = runner._init_prompt_tools()
assert "agent_tool" in tools
def test_exception_in_conversion(self, runner, mocker):
agent_tool = mocker.MagicMock(tool_name="bad_tool")
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
tools, prompts = runner._init_prompt_tools()
assert tools == {}
# ==========================================================
# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS)
# ==========================================================
class TestAdditionalCoverage:
def test_update_prompt_with_input_schema(self, runner, mocker):
tool = mocker.MagicMock()
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "p1"
param.required = False
param.llm_description = "desc"
param.options = None
param.input_schema = {"type": "number"}
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
param.type = mock_type
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["p1"]["type"] == "number"
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {"tool1": {"en_US": "existing"}}
agent.thought = ""
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
labels = json.loads(agent.tool_labels_str)
assert labels["tool1"]["en_US"] == "existing"
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
agent.thought = ""
mock_db_session.scalar.return_value = agent
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
assert agent.tool_meta_str == "meta_string"
def test_convert_dataset_retriever_tool(self, runner, mocker):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
param = mocker.MagicMock()
param.name = "query"
param.llm_description = "desc"
param.required = True
ds_tool.get_runtime_parameters.return_value = [param]
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
assert prompt is not None
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
file_config = mocker.MagicMock()
file_config.image_config = mocker.MagicMock(detail=None)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"])
mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock())
mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw))
mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw))
msg = mocker.MagicMock(id="1", query="hello")
msg.app_model_config.to_dict.return_value = {}
result = runner.organize_agent_user_prompt(msg)
assert result is not None
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(tool=None, thought="thinking")
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
result = runner.organize_agent_history([])
assert isinstance(result, list)
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1;tool2",
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
observation=json.dumps({"tool1": "o1", "tool2": "o2"}),
thought="thinking",
)
msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
result = runner.organize_agent_history([])
assert isinstance(result, list)
# ================= Additional Surgical Coverage =================
def test_convert_tool_select_enum_branch(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "select_param"
param.required = True
param.llm_description = "desc"
param.input_schema = None
option1 = mocker.MagicMock(value="A")
option2 = mocker.MagicMock(value="B")
param.options = [option1, option2]
param.type = module.ToolParameter.ToolParameterType.SELECT
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
assert prompt_tool is not None
class TestConvertDatasetRetrieverTool:
def test_required_param_added(self, runner, mocker):
ds_tool = mocker.MagicMock()
ds_tool.entity.identity.name = "ds"
ds_tool.entity.description.llm = "desc"
param = mocker.MagicMock()
param.name = "query"
param.llm_description = "desc"
param.required = True
ds_tool.get_runtime_parameters.return_value = [param]
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
assert prompt is not None
class TestBaseAgentRunnerInit:
def test_init_sets_stream_tool_call_and_files(self, mocker):
session = mocker.MagicMock()
session.query.return_value.where.return_value.count.return_value = 2
mocker.patch.object(module.db, "session", session)
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"])
llm = mocker.MagicMock()
llm.get_model_schema.return_value = mocker.MagicMock(
features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION]
)
model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
app_config = mocker.MagicMock()
app_config.app_id = "app1"
app_config.agent = None
app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"})
app_config.additional_features = mocker.MagicMock(show_retrieve_source=True)
app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
message = mocker.MagicMock(id="msg1", conversation_id="conv1")
runner = BaseAgentRunner(
tenant_id="tenant",
application_generate_entity=app_generate,
conversation=mocker.MagicMock(),
app_config=app_config,
model_config=mocker.MagicMock(),
config=mocker.MagicMock(),
queue_manager=mocker.MagicMock(),
message=message,
user_id="user",
model_instance=model_instance,
)
assert runner.stream_tool_call is True
assert runner.files == ["file1"]
assert runner.dataset_tools == ["ds_tool"]
assert runner.agent_thought_count == 2
class TestBaseAgentRunnerCoverage:
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
param.form = "NOT_LLM"
param.type = mocker.MagicMock()
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
assert prompt_tool.parameters["properties"] == {}
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
dataset_tool = mocker.MagicMock()
dataset_tool.entity.identity.name = "ds"
runner.dataset_tools = [dataset_tool]
mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock())
tools, prompt_tools = runner._init_prompt_tools()
assert tools["ds"] == dataset_tool
assert len(prompt_tools) == 1
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
tool = mocker.MagicMock()
option1 = mocker.MagicMock(value="A")
option2 = mocker.MagicMock(value="B")
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "select_param"
param.required = False
param.llm_description = "desc"
param.input_schema = None
param.options = [option1, option2]
param.type = module.ToolParameter.ToolParameterType.SELECT
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1"
agent.tool_labels = {}
agent.thought = ""
mock_db_session.scalar.return_value = agent
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
tool_input = {"a": 1}
observation = {"b": 2}
tool_meta = {"c": 3}
real_dumps = json.dumps
def dumps_side_effect(value, *args, **kwargs):
if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False:
raise TypeError("fail")
return real_dumps(value, *args, **kwargs)
mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect)
runner.save_agent_thought(
"id",
"tool1",
tool_input,
None,
observation,
tool_meta,
None,
[],
None,
)
assert isinstance(agent.tool_input, str)
assert isinstance(agent.observation, str)
assert isinstance(agent.tool_meta_str, str)
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
agent = mocker.MagicMock()
agent.tool = "tool1;;"
agent.tool_labels = {}
agent.thought = ""
mock_db_session.scalar.return_value = agent
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
labels = json.loads(agent.tool_labels_str)
assert "" not in labels
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
mocker.patch.object(module, "extract_thread_messages", return_value=[])
system_message = module.SystemPromptMessage(content="sys")
result = runner.organize_agent_history([system_message])
assert system_message in result
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
thought = mocker.MagicMock(
tool="tool1",
tool_input=None,
observation=None,
thought="thinking",
)
msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None)
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
mocker.patch("uuid.uuid4", return_value="uuid")
mocker.patch.object(
runner,
"organize_agent_user_prompt",
return_value=module.UserPromptMessage(content="user"),
)
result = runner.organize_agent_history([])
assert any(isinstance(item, module.ToolPromptMessage) for item in result)

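The fixtures above stub the database layer by configuring a chain of `return_value` attributes on a single mock. A minimal, self-contained sketch of that chained-mock pattern (the `session` name and query string here are illustrative, not part of the code under test):

```python
from unittest.mock import MagicMock

# One assignment configures the whole SQLAlchemy-style call chain
# session.execute(...).scalars().all().
session = MagicMock()
session.execute.return_value.scalars.return_value.all.return_value = ["msg"]

# Any arguments are accepted; the configured leaf value comes back.
assert session.execute("select ...").scalars().all() == ["msg"]
```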

@@ -1,551 +0,0 @@
import json
from unittest.mock import MagicMock
import pytest
from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.entities import AgentScratchpadUnit
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.agent.exc import AgentMaxIterationError
class DummyRunner(CotAgentRunner):
"""Concrete implementation for testing abstract methods."""
def __init__(self, **kwargs):
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
for k, v in kwargs.items():
setattr(self, k, v)
# Minimal required defaults
self.history_prompt_messages = []
self.memory = None
def _organize_prompt_messages(self):
return []
@pytest.fixture
def runner(mocker):
# Prevent BaseAgentRunner __init__ from hitting database
mocker.patch(
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
return_value=[],
)
# Prepare required constructor dependencies for BaseAgentRunner
application_generate_entity = MagicMock()
application_generate_entity.model_conf = MagicMock()
application_generate_entity.model_conf.stop = []
application_generate_entity.model_conf.provider = "openai"
application_generate_entity.model_conf.parameters = {}
application_generate_entity.trace_manager = None
application_generate_entity.invoke_from = "test"
app_config = MagicMock()
app_config.agent = MagicMock()
app_config.agent.max_iteration = 1
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.model_name = "test-model"
model_instance.invoke_llm.return_value = []
model_config = MagicMock()
model_config.model = "test-model"
queue_manager = MagicMock()
message = MagicMock()
runner = DummyRunner(
tenant_id="tenant",
application_generate_entity=application_generate_entity,
conversation=MagicMock(),
app_config=app_config,
model_config=model_config,
config=MagicMock(),
queue_manager=queue_manager,
message=message,
user_id="user",
model_instance=model_instance,
)
# Patch internal methods to isolate behavior
runner._repack_app_generate_entity = MagicMock()
runner._init_prompt_tools = MagicMock(return_value=({}, []))
runner.recalc_llm_max_tokens = MagicMock()
runner.create_agent_thought = MagicMock(return_value="thought-id")
runner.save_agent_thought = MagicMock()
runner.update_prompt_message_tool = MagicMock()
runner.agent_callback = None
runner.memory = None
runner.history_prompt_messages = []
return runner
class TestFillInputs:
@pytest.mark.parametrize(
("instruction", "inputs", "expected"),
[
("Hello {{name}}", {"name": "John"}, "Hello John"),
("No placeholders", {"name": "John"}, "No placeholders"),
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
("{{x}}", {"x": None}, "None"),
("", {"x": "y"}, ""),
],
)
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
assert result == expected
class TestConvertDictToAction:
def test_convert_valid_dict(self, runner):
action_dict = {"action": "test", "action_input": {"a": 1}}
action = runner._convert_dict_to_action(action_dict)
assert action.action_name == "test"
assert action.action_input == {"a": 1}
def test_convert_missing_keys(self, runner):
with pytest.raises(KeyError):
runner._convert_dict_to_action({"invalid": 1})
class TestFormatAssistantMessage:
def test_format_assistant_message_multiple_scratchpads(self, runner):
sp1 = AgentScratchpadUnit(
agent_response="resp1",
thought="thought1",
action_str="action1",
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
observation="obs1",
)
sp2 = AgentScratchpadUnit(
agent_response="final",
thought="",
action_str="",
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
observation=None,
)
result = runner._format_assistant_message([sp1, sp2])
assert "Final Answer:" in result
def test_format_with_final(self, runner):
scratchpad = AgentScratchpadUnit(
agent_response="Done",
thought="",
action_str="",
action=None,
observation=None,
)
# Simulate final state via action name
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
result = runner._format_assistant_message([scratchpad])
assert "Final Answer" in result
def test_format_with_action_and_observation(self, runner):
scratchpad = AgentScratchpadUnit(
agent_response="resp",
thought="thinking",
action_str="action",
action=None,
observation="obs",
)
# Non-final state: provide a non-final action
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
result = runner._format_assistant_message([scratchpad])
assert "Thought:" in result
assert "Action:" in result
assert "Observation:" in result
class TestHandleInvokeAction:
def test_handle_invoke_action_tool_not_present(self, runner):
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
response, meta = runner._handle_invoke_action(action, {}, [])
assert "there is not a tool named" in response
def test_tool_with_json_string_args(self, runner, mocker):
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("result", [], MagicMock(to_dict=lambda: {})),
)
response, meta = runner._handle_invoke_action(action, tool_instances, [])
assert response == "result"
class TestOrganizeHistoricPromptMessages:
def test_empty_history(self, runner, mocker):
mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
return_value=[],
)
result = runner._organize_historic_prompt_messages([])
assert result == []
class TestRun:
def test_run_handles_empty_parser_output(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
results = list(runner.run(message, "query", {}))
assert isinstance(results, list)
def test_run_with_action_and_tool_invocation(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
runner.agent_callback = None
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_respects_max_iteration_boundary(self, runner, mocker):
runner.app_config.agent.max_iteration = 1
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
runner.agent_callback = None
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {"tool": MagicMock()}))
def test_run_basic_flow(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
results = list(runner.run(message, "query", {"name": "John"}))
assert results
def test_run_max_iteration_error(self, runner, mocker):
runner.app_config.agent.max_iteration = 0
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query", {}))
def test_run_increase_usage_aggregation(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
runner.app_config.agent.max_iteration = 2
usage_1 = LLMUsage.empty_usage()
usage_1.prompt_tokens = 1
usage_1.completion_tokens = 1
usage_1.total_tokens = 2
usage_1.prompt_price = 1
usage_1.completion_price = 1
usage_1.total_price = 2
usage_2 = LLMUsage.empty_usage()
usage_2.prompt_tokens = 1
usage_2.completion_tokens = 1
usage_2.total_tokens = 2
usage_2.prompt_price = 1
usage_2.completion_price = 1
usage_2.total_price = 2
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
handle_output = mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
)
def _handle_side_effect(chunks, usage_dict):
call_index = handle_output.call_count
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
return [action] if call_index == 1 else []
handle_output.side_effect = _handle_side_effect
runner.model_instance.invoke_llm = MagicMock(return_value=[])
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
fake_prompt_tool = MagicMock()
fake_prompt_tool.name = "tool"
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
results = list(runner.run(message, "query", {}))
final_usage = results[-1].delta.usage
assert final_usage is not None
assert final_usage.prompt_tokens == 2
assert final_usage.completion_tokens == 2
assert final_usage.total_tokens == 4
assert final_usage.prompt_price == 2
assert final_usage.completion_price == 2
assert final_usage.total_price == 4
def test_run_when_no_action_branch(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == ""
def test_run_usage_missing_key_branch(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[],
)
runner.model_instance.invoke_llm = MagicMock(return_value=[])
list(runner.run(message, "query", {}))
def test_run_prompt_tool_update_branch(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
# First iteration → action
# Second iteration → no action (empty list)
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
side_effect=[[action], []],
)
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
)
runner.app_config.agent.max_iteration = 5
fake_prompt_tool = MagicMock()
fake_prompt_tool.name = "tool"
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
runner.update_prompt_message_tool = MagicMock()
runner.agent_callback = None
list(runner.run(message, "query", {}))
runner.update_prompt_message_tool.assert_called_once()
def test_historic_with_assistant_and_tool_calls(self, runner):
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
assistant = AssistantPromptMessage(content="thinking")
# MagicMock treats the `name` kwarg specially (it names the mock's repr),
# so .name must be assigned after construction.
function_mock = MagicMock(arguments='{"a":1}')
function_mock.name = "tool"
assistant.tool_calls = [MagicMock(function=function_mock)]
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
runner.history_prompt_messages = [assistant, tool_msg]
result = runner._organize_historic_prompt_messages([])
assert isinstance(result, list)
def test_historic_final_flush_branch(self, runner):
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
assistant = AssistantPromptMessage(content="final")
runner.history_prompt_messages = [assistant]
result = runner._organize_historic_prompt_messages([])
assert isinstance(result, list)
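Several tests in this file give mocks a `.name` attribute. Worth noting: `unittest.mock` reserves the `name` keyword at construction time for the mock's repr, so a `.name` attribute has to be assigned afterwards. A standalone illustration of the gotcha:

```python
from unittest.mock import MagicMock

m = MagicMock(name="tool")          # names the mock's repr, NOT an attribute
assert not isinstance(m.name, str)  # m.name is itself a child Mock here

m2 = MagicMock()
m2.name = "tool"                    # plain assignment sets a real attribute
assert m2.name == "tool"
```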
class TestInitReactState:
def test_init_react_state_resets_state(self, runner, mocker):
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
runner._agent_scratchpad = ["old"]
runner._query = "old"
runner._init_react_state("new-query")
assert runner._query == "new-query"
assert runner._agent_scratchpad == []
assert runner._historic_prompt_messages == ["historic"]
class TestHandleInvokeActionExtended:
def test_tool_with_invalid_json_string_args(self, runner, mocker):
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
tool_instance = MagicMock()
tool_instances = {"tool": tool_instance}
mocker.patch(
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
)
message_file_ids = []
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
assert response == "ok"
assert message_file_ids == ["file1"]
runner.queue_manager.publish.assert_called()
class TestFillInputsEdgeCases:
def test_fill_inputs_with_empty_inputs(self, runner):
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
assert result == "Hello {{x}}"
def test_fill_inputs_with_exception_in_replace(self, runner):
class BadValue:
def __str__(self):
raise Exception("fail")
# Should silently continue on exception
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
assert result == "Hello {{x}}"
class TestOrganizeHistoricPromptMessagesExtended:
def test_user_message_flushes_scratchpad(self, runner, mocker):
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
user_message = UserPromptMessage(content="Hi")
runner.history_prompt_messages = [user_message]
mock_transform = mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
)
mock_transform.return_value.get_prompt.return_value = ["final"]
result = runner._organize_historic_prompt_messages([])
assert result == ["final"]
def test_tool_message_without_scratchpad_raises(self, runner):
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
with pytest.raises(NotImplementedError):
runner._organize_historic_prompt_messages([])
def test_agent_history_transform_invocation(self, runner, mocker):
mock_transform = MagicMock()
mock_transform.get_prompt.return_value = []
mocker.patch(
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
return_value=mock_transform,
)
runner.history_prompt_messages = []
result = runner._organize_historic_prompt_messages([])
assert result == []
class TestRunAdditionalBranches:
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=["thinking"],
)
results = list(runner.run(message, "query", {}))
assert any(hasattr(r, "delta") for r in results)
def test_run_with_final_answer_action_string(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == "done"
def test_run_with_final_answer_action_dict(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
results = list(runner.run(message, "query", {}))
assert json.loads(results[-1].delta.message.content) == {"a": 1}
def test_run_with_string_final_answer(self, runner, mocker):
message = MagicMock()
message.id = "msg-id"
# action_input must be str | dict (Pydantic-validated), so use a string final answer
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
mocker.patch(
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
return_value=[action],
)
results = list(runner.run(message, "query", {}))
assert results[-1].delta.message.content == "12345"

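`test_run_increase_usage_aggregation` and `test_run_prompt_tool_update_branch` drive multiple iterations by giving the patched parser a per-call `side_effect` sequence. A minimal sketch of that sequencing mechanism, independent of the runner (the `parser` name is illustrative):

```python
from unittest.mock import MagicMock

# Each call consumes the next item in the side_effect sequence:
# first call yields an action, second call yields nothing.
parser = MagicMock(side_effect=[["action"], []])
first = parser("chunk-1")
second = parser("chunk-2")
assert first == ["action"]
assert second == []
assert parser.call_count == 2
```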

@@ -1,215 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
from tests.unit_tests.core.agent.conftest import (
DummyAgentConfig,
DummyAppConfig,
DummyPromptEntity as DummyPrompt,
DummyTool,
)
class DummyFileUploadConfig:
def __init__(self, image_config=None):
self.image_config = image_config
class DummyImageConfig:
def __init__(self, detail=None):
self.detail = detail
class DummyGenerateEntity:
def __init__(self, file_upload_config=None):
self.file_upload_config = file_upload_config
class DummyUnit:
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
self._final = final
self.thought = thought
self.action_str = action_str
self.observation = observation
self.agent_response = agent_response
def is_final(self):
return self._final
@pytest.fixture
def runner():
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
runner._instruction = "test_instruction"
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
runner._query = "user query"
runner._agent_scratchpad = []
runner.files = []
runner.application_generate_entity = DummyGenerateEntity()
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
return runner
class TestOrganizeSystemPrompt:
def test_organize_system_prompt_success(self, runner, mocker):
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
mocker.patch(
"core.agent.cot_chat_agent_runner.jsonable_encoder",
return_value=[{"name": "tool1"}, {"name": "tool2"}],
)
result = runner._organize_system_prompt()
assert "test_instruction" in result.content
assert "tool1" in result.content
assert "tool2" in result.content
assert "tool1, tool2" in result.content
def test_organize_system_prompt_missing_agent(self, runner):
runner.app_config = DummyAppConfig(agent=None)
with pytest.raises(AssertionError):
runner._organize_system_prompt()
def test_organize_system_prompt_missing_prompt(self, runner):
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
with pytest.raises(AssertionError):
runner._organize_system_prompt()
class TestOrganizeUserQuery:
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
def test_organize_user_query_no_files(self, runner, files):
runner.files = files
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert result[0].content == "query"
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
mock_content = ImagePromptMessageContent(
url="http://test",
format="png",
mime_type="image/png",
)
mock_to_prompt.return_value = mock_content
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
runner.files = ["file1"]
runner.application_generate_entity = DummyGenerateEntity(None)
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
assert mock_content in result[0].content
mock_to_prompt.assert_called_once_with(
"file1",
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
)
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
mock_content = ImagePromptMessageContent(
url="http://test",
format="png",
mime_type="image/png",
)
mock_to_prompt.return_value = mock_content
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
runner.files = ["file1"]
image_config = DummyImageConfig(detail="high")
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
assert mock_content in result[0].content
mock_to_prompt.assert_called_once_with(
"file1",
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
)
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
runner.files = ["file1"]
runner.application_generate_entity = DummyGenerateEntity(None)
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
class TestOrganizePromptMessages:
def test_no_scratchpad(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
result = runner._organize_prompt_messages()
assert "system" in result
assert "query" in result
runner._organize_historic_prompt_messages.assert_called_once()
def test_with_final_scratchpad(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
unit = DummyUnit(final=True, agent_response="done")
runner._agent_scratchpad = [unit]
result = runner._organize_prompt_messages()
assistant_msgs = [m for m in result if hasattr(m, "content")]
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Final Answer: done" in combined
def test_with_thought_action_observation(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
unit = DummyUnit(
final=False,
thought="thinking",
action_str="action",
observation="observe",
)
runner._agent_scratchpad = [unit]
result = runner._organize_prompt_messages()
assistant_msgs = [m for m in result if hasattr(m, "content")]
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Thought: thinking" in combined
assert "Action: action" in combined
assert "Observation: observe" in combined
def test_multiple_units_mixed(self, runner, mocker):
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
runner._organize_system_prompt = MagicMock(return_value="system")
runner._organize_user_query = MagicMock(return_value=["query"])
units = [
DummyUnit(final=False, thought="t1"),
DummyUnit(final=True, agent_response="done"),
]
runner._agent_scratchpad = units
result = runner._organize_prompt_messages()
assistant_msgs = [m for m in result if hasattr(m, "content")]
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
assert "Thought: t1" in combined
assert "Final Answer: done" in combined


@@ -1,234 +0,0 @@
import json
import pytest
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
# -----------------------------
# Fixtures
# -----------------------------
@pytest.fixture
def runner(mocker, dummy_tool_factory):
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
runner._instruction = "Test instruction"
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
runner._query = "What is Python?"
runner._agent_scratchpad = []
mocker.patch(
"core.agent.cot_completion_agent_runner.jsonable_encoder",
side_effect=lambda tools: [{"name": t.name} for t in tools],
)
return runner
# ======================================================
# _organize_instruction_prompt Tests
# ======================================================
class TestOrganizeInstructionPrompt:
def test_success_all_placeholders(
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
):
template = (
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
)
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
result = runner._organize_instruction_prompt()
assert "Test instruction" in result
assert "toolA" in result
assert "toolB" in result
tools_payload = json.loads(result.split(" | ")[1])
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
def test_agent_none_raises(self, runner, dummy_app_config_factory):
runner.app_config = dummy_app_config_factory(agent=None)
with pytest.raises(ValueError, match="Agent configuration is not set"):
runner._organize_instruction_prompt()
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
with pytest.raises(ValueError, match="prompt entity is not set"):
runner._organize_instruction_prompt()
# ======================================================
# _organize_historic_prompt Tests
# ======================================================
class TestOrganizeHistoricPrompt:
def test_with_user_and_assistant_string(self, runner, mocker):
user_msg = UserPromptMessage(content="Hello")
assistant_msg = AssistantPromptMessage(content="Hi there")
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[user_msg, assistant_msg],
)
result = runner._organize_historic_prompt()
assert "Question: Hello" in result
assert "Hi there" in result
def test_assistant_list_with_text_content(self, runner, mocker):
text_content = TextPromptMessageContent(data="Partial answer")
assistant_msg = AssistantPromptMessage(content=[text_content])
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[assistant_msg],
)
result = runner._organize_historic_prompt()
assert "Partial answer" in result
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
assistant_msg = AssistantPromptMessage(content=[non_text_content])
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[assistant_msg],
)
result = runner._organize_historic_prompt()
assert result == ""
def test_empty_history(self, runner, mocker):
mocker.patch.object(
runner,
"_organize_historic_prompt_messages",
return_value=[],
)
result = runner._organize_historic_prompt()
assert result == ""
# ======================================================
# _organize_prompt_messages Tests
# ======================================================
class TestOrganizePromptMessages:
def test_full_flow_with_scratchpad(
self,
runner,
mocker,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
dummy_scratchpad_unit_factory,
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
runner._agent_scratchpad = [
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
]
result = runner._organize_prompt_messages()
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
content = result[0].content
assert "History" in content
assert "Thought: Thinking" in content
assert "Action: Act" in content
assert "Observation: Obs" in content
assert "Final Answer: Done" in content
assert "Question: What is Python?" in content
def test_no_scratchpad(
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
runner._agent_scratchpad = None
result = runner._organize_prompt_messages()
assert "Question: What is Python?" in result[0].content
@pytest.mark.parametrize(
("thought", "action", "observation"),
[
("T", None, None),
("T", "A", None),
("T", None, "O"),
],
)
def test_partial_scratchpad_units(
self,
runner,
mocker,
thought,
action,
observation,
dummy_app_config_factory,
dummy_agent_config_factory,
dummy_prompt_entity_factory,
dummy_scratchpad_unit_factory,
):
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
runner.app_config = dummy_app_config_factory(
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
)
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
runner._agent_scratchpad = [
dummy_scratchpad_unit_factory(
final=False,
thought=thought,
action_str=action,
observation=observation,
)
]
result = runner._organize_prompt_messages()
content = result[0].content
assert "Thought:" in content
if action:
assert "Action:" in content
if observation:
assert "Observation:" in content

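The templates in these tests use `{{placeholder}}` markers that the runner fills in. A minimal substitution helper in the same spirit (illustrative only, not the project's actual renderer):

```python
def fill(template: str, values: dict[str, str]) -> str:
    # Naive {{key}} substitution; unknown keys are left untouched.
    for key, value in values.items():
        template = template.replace("{{" + key + "}}", value)
    return template


prompt = fill(
    "{{instruction}} | {{tool_names}} | {{query}}",
    {
        "instruction": "Test instruction",
        "tool_names": "toolA, toolB",
        "query": "What is Python?",
    },
)
```

Because unknown placeholders survive substitution verbatim, tests can assert both that known values were injected and that the surrounding template structure is intact.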

@@ -1,452 +0,0 @@
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueMessageFileEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
DocumentPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from dify_graph.nodes.agent.exc import AgentMaxIterationError
# ==============================
# Dummy Helper Classes
# ==============================
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
usage = LLMUsage.empty_usage()
usage.prompt_tokens = pt
usage.completion_tokens = ct
usage.total_tokens = tt
usage.prompt_price = 0
usage.completion_price = 0
usage.total_price = 0
return usage
class DummyMessage:
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
self.content: str | None = content
self.tool_calls: list[Any] = tool_calls or []
class DummyDelta:
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
self.message: DummyMessage | None = message
self.usage: LLMUsage | None = usage
class DummyChunk:
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
class DummyResult:
def __init__(
self,
message: DummyMessage | None = None,
usage: LLMUsage | None = None,
prompt_messages: list[DummyMessage] | None = None,
):
self.message: DummyMessage | None = message
self.usage: LLMUsage | None = usage
self.prompt_messages: list[DummyMessage] = prompt_messages or []
self.system_fingerprint: str = ""
# ==============================
# Fixtures
# ==============================
@pytest.fixture
def runner(mocker):
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
mocker.patch(
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
return_value=None,
)
# Patch streaming chunk models to avoid validation on dummy message objects
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
app_config = MagicMock()
app_config.agent = MagicMock(max_iteration=2)
app_config.prompt_template = MagicMock(simple_prompt_template="system")
application_generate_entity = MagicMock()
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
application_generate_entity.trace_manager = MagicMock()
application_generate_entity.invoke_from = "test"
application_generate_entity.app_config = MagicMock(app_id="app")
application_generate_entity.file_upload_config = None
queue_manager = MagicMock()
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.model_name = "test-model"
message = MagicMock(id="msg1")
conversation = MagicMock(id="conv1")
runner = FunctionCallAgentRunner(
tenant_id="tenant",
application_generate_entity=application_generate_entity,
conversation=conversation,
app_config=app_config,
model_config=MagicMock(),
config=MagicMock(),
queue_manager=queue_manager,
message=message,
user_id="user",
model_instance=model_instance,
)
# Manually inject required attributes normally set by BaseAgentRunner
runner.tenant_id = "tenant"
runner.application_generate_entity = application_generate_entity
runner.conversation = conversation
runner.app_config = app_config
runner.model_config = MagicMock()
runner.config = MagicMock()
runner.queue_manager = queue_manager
runner.message = message
runner.user_id = "user"
runner.model_instance = model_instance
runner.stream_tool_call = False
runner.memory = None
runner.history_prompt_messages = []
runner._current_thoughts = []
runner.files = []
runner.agent_callback = MagicMock()
runner._init_prompt_tools = MagicMock(return_value=({}, []))
runner.create_agent_thought = MagicMock(return_value="thought1")
runner.save_agent_thought = MagicMock()
runner.recalc_llm_max_tokens = MagicMock()
runner.update_prompt_message_tool = MagicMock()
return runner
# ==============================
# Tool Call Checks
# ==============================
class TestToolCallChecks:
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
def test_check_tool_calls(self, runner, tool_calls, expected):
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
assert runner.check_tool_calls(chunk) is expected
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
assert runner.check_blocking_tool_calls(result) is expected
# ==============================
# Extract Tool Calls
# ==============================
class TestExtractToolCalls:
def test_extract_tool_calls_with_valid_json(self, runner):
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
calls = runner.extract_tool_calls(chunk)
assert calls == [("1", "tool", {"a": 1})]
def test_extract_tool_calls_empty_arguments(self, runner):
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = ""
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
calls = runner.extract_tool_calls(chunk)
assert calls == [("1", "tool", {})]
def test_extract_blocking_tool_calls(self, runner):
tool_call = MagicMock()
tool_call.id = "2"
tool_call.function.name = "block"
tool_call.function.arguments = json.dumps({"x": 2})
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
calls = runner.extract_blocking_tool_calls(result)
assert calls == [("2", "block", {"x": 2})]
# ==============================
# System Message Initialization
# ==============================
class TestInitSystemMessage:
def test_init_system_message_empty_prompt_messages(self, runner):
result = runner._init_system_message("system", [])
assert len(result) == 1
def test_init_system_message_insert_at_start(self, runner):
msgs = [MagicMock()]
result = runner._init_system_message("system", msgs)
assert result[0].content == "system"
def test_init_system_message_no_template(self, runner):
result = runner._init_system_message("", [])
assert result == []
# ==============================
# Organize User Query
# ==============================
class TestOrganizeUserQuery:
def test_without_files(self, runner):
result = runner._organize_user_query("query", [])
assert len(result) == 1
def test_with_none_query(self, runner):
result = runner._organize_user_query(None, [])
assert len(result) == 1
def test_with_files_uses_image_detail_config(self, runner, mocker):
file_content = TextPromptMessageContent(data="file-content")
mock_to_prompt = mocker.patch(
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
return_value=file_content,
)
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
runner.files = ["file1"]
result = runner._organize_user_query("query", [])
assert len(result) == 1
assert isinstance(result[0].content, list)
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
# ==============================
# Clear User Prompt Images
# ==============================
class TestClearUserPromptImageMessages:
def test_clear_text_and_image_content(self, runner):
text = MagicMock()
text.type = "text"
text.data = "hello"
image = MagicMock()
image.type = "image"
image.data = "img"
user_msg = MagicMock()
user_msg.__class__.__name__ = "UserPromptMessage"
user_msg.content = [text, image]
result = runner._clear_user_prompt_image_messages([user_msg])
assert isinstance(result, list)
def test_clear_includes_file_placeholder(self, runner):
text = TextPromptMessageContent(data="hello")
image = ImagePromptMessageContent(format="url", mime_type="image/png")
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
user_msg = UserPromptMessage(content=[text, image, document])
result = runner._clear_user_prompt_image_messages([user_msg])
assert result[0].content == "hello\n[image]\n[file]"
# ==============================
# Run Method Tests
# ==============================
class TestRunMethod:
def test_run_non_streaming_no_tool_calls(self, runner):
message = MagicMock(id="m1")
dummy_message = DummyMessage(content="hello")
result = DummyResult(message=dummy_message, usage=build_usage())
runner.model_instance.invoke_llm.return_value = result
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
runner.queue_manager.publish.assert_called()
queue_calls = runner.queue_manager.publish.call_args_list
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
def test_run_streaming_branch(self, runner):
message = MagicMock(id="m1")
runner.stream_tool_call = True
content = [TextPromptMessageContent(data="hi")]
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
def generator():
yield chunk
runner.model_instance.invoke_llm.return_value = generator()
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
def test_run_streaming_tool_calls_list_content(self, runner):
message = MagicMock(id="m1")
runner.stream_tool_call = True
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
content = [TextPromptMessageContent(data="hi")]
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
def generator():
yield chunk
final_message = DummyMessage(content="done", tool_calls=[])
final_result = DummyResult(message=final_message, usage=build_usage())
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_non_streaming_list_content(self, runner):
message = MagicMock(id="m1")
content = [TextPromptMessageContent(data="hi")]
dummy_message = DummyMessage(content=content)
result = DummyResult(message=dummy_message, usage=build_usage())
runner.model_instance.invoke_llm.return_value = result
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
message = MagicMock(id="m1")
runner.stream_tool_call = True
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
def generator():
yield chunk
runner.model_instance.invoke_llm.return_value = generator()
real_dumps = json.dumps
def flaky_dumps(obj, *args, **kwargs):
if kwargs.get("ensure_ascii") is False:
return real_dumps(obj, *args, **kwargs)
raise TypeError("boom")
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
outputs = list(runner.run(message, "query"))
assert len(outputs) == 1
def test_run_with_missing_tool_instance(self, runner):
message = MagicMock(id="m1")
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "missing"
tool_call.function.arguments = json.dumps({})
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
result = DummyResult(message=dummy_message, usage=build_usage())
final_message = DummyMessage(content="done", tool_calls=[])
final_result = DummyResult(message=final_message, usage=build_usage())
runner.model_instance.invoke_llm.side_effect = [result, final_result]
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
def test_run_with_tool_instance_and_files(self, runner, mocker):
message = MagicMock(id="m1")
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = json.dumps({"a": 1})
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
result = DummyResult(message=dummy_message, usage=build_usage())
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
runner.model_instance.invoke_llm.side_effect = [result, final_result]
tool_instance = MagicMock()
prompt_tool = MagicMock()
prompt_tool.name = "tool"
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
tool_invoke_meta = MagicMock()
tool_invoke_meta.to_dict.return_value = {"ok": True}
mocker.patch(
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
return_value=("ok", ["file1"], tool_invoke_meta),
)
outputs = list(runner.run(message, "query"))
assert len(outputs) >= 1
assert any(
isinstance(call.args[0], QueueMessageFileEvent)
and call.args[0].message_file_id == "file1"
and call.args[1] == PublishFrom.APPLICATION_MANAGER
for call in runner.queue_manager.publish.call_args_list
)
def test_run_max_iteration_error(self, runner):
runner.app_config.agent.max_iteration = 0
message = MagicMock(id="m1")
tool_call = MagicMock()
tool_call.id = "1"
tool_call.function.name = "tool"
tool_call.function.arguments = "{}"
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
result = DummyResult(message=dummy_message, usage=build_usage())
runner.model_instance.invoke_llm.return_value = result
with pytest.raises(AgentMaxIterationError):
list(runner.run(message, "query"))

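Several tests above scan `queue_manager.publish.call_args_list` for a specific event type. The pattern in isolation looks like this (the `FileEvent` class is a stand-in for the real queue entity):

```python
from unittest.mock import MagicMock


class FileEvent:
    def __init__(self, message_file_id: str):
        self.message_file_id = message_file_id


publish = MagicMock()
publish(FileEvent("file1"), "APPLICATION_MANAGER")
publish("other-event")

# Scan every recorded call for a FileEvent carrying the expected id.
matched = any(
    call.args
    and isinstance(call.args[0], FileEvent)
    and call.args[0].message_file_id == "file1"
    for call in publish.call_args_list
)
```

Guarding with `call.args` first matters: calls made with keyword arguments only would have empty `args`, and indexing `call.args[0]` unconditionally would raise `IndexError`.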

@@ -1,324 +0,0 @@
"""Unit tests for core.agent.plugin_entities.
Covers entities such as AgentFeature, AgentProviderEntityWithPlugin,
AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter,
AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on
Pydantic ValidationError behavior and pytest fixtures for validation and
mocking, ensuring entity invariants and validation rules remain stable.
"""
import pytest
from pydantic import ValidationError
from core.agent.plugin_entities import (
AgentFeature,
AgentProviderEntityWithPlugin,
AgentStrategyEntity,
AgentStrategyIdentity,
AgentStrategyParameter,
AgentStrategyProviderEntity,
AgentStrategyProviderIdentity,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
# =========================================================
# Fixtures
# =========================================================
@pytest.fixture
def mock_identity(mocker):
return mocker.MagicMock(spec=AgentStrategyIdentity)
@pytest.fixture
def mock_provider_identity(mocker):
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
# =========================================================
# AgentStrategyParameterType Tests
# =========================================================
class TestAgentStrategyParameterType:
@pytest.mark.parametrize(
"enum_member",
list(AgentStrategyParameter.AgentStrategyParameterType),
)
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.as_normal_type",
return_value="normalized",
)
result = enum_member.as_normal_type()
mock_func.assert_called_once_with(enum_member)
assert result == "normalized"
def test_as_normal_type_propagates_exception(self, mocker) -> None:
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
mocker.patch(
"core.agent.plugin_entities.as_normal_type",
side_effect=RuntimeError("boom"),
)
with pytest.raises(RuntimeError):
enum_member.as_normal_type()
@pytest.mark.parametrize(
("enum_member", "value"),
[
(AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"),
(AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10),
(AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True),
(AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}),
(AgentStrategyParameter.AgentStrategyParameterType.STRING, None),
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
],
)
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.cast_parameter_value",
return_value="casted",
)
result = enum_member.cast_value(value)
mock_func.assert_called_once_with(enum_member, value)
assert result == "casted"
def test_cast_value_propagates_exception(self, mocker) -> None:
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
mocker.patch(
"core.agent.plugin_entities.cast_parameter_value",
side_effect=ValueError("invalid"),
)
with pytest.raises(ValueError):
enum_member.cast_value("bad")
# =========================================================
# AgentStrategyParameter Tests
# =========================================================
class TestAgentStrategyParameter:
def test_valid_creation_minimal(self) -> None:
# bypass base PluginParameter required fields using model_construct
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
help=None,
)
assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING
assert param.help is None
def test_valid_creation_with_help(self) -> None:
help_obj = I18nObject(en_US="test")
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
help=help_obj,
)
assert param.help == help_obj
@pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}])
def test_invalid_type_raises_validation_error(self, invalid_type) -> None:
with pytest.raises(ValidationError) as exc_info:
AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y"))
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
mock_func = mocker.patch(
"core.agent.plugin_entities.init_frontend_parameter",
return_value="frontend",
)
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
)
result = param.init_frontend_parameter("value")
mock_func.assert_called_once_with(param, param.type, "value")
assert result == "frontend"
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
mocker.patch(
"core.agent.plugin_entities.init_frontend_parameter",
side_effect=RuntimeError("error"),
)
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
)
with pytest.raises(RuntimeError):
param.init_frontend_parameter("value")
# =========================================================
# AgentStrategyProviderEntity Tests
# =========================================================
class TestAgentStrategyProviderEntity:
def test_creation_with_plugin_id(self, mock_provider_identity) -> None:
entity = AgentStrategyProviderEntity(
identity=mock_provider_identity,
plugin_id="plugin-123",
)
assert entity.plugin_id == "plugin-123"
def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None:
entity = AgentStrategyProviderEntity(
identity=mock_provider_identity,
plugin_id="",
)
assert entity.plugin_id == ""
def test_creation_without_plugin_id(self, mock_provider_identity) -> None:
entity = AgentStrategyProviderEntity(identity=mock_provider_identity)
assert entity.plugin_id is None
def test_invalid_identity_raises(self) -> None:
with pytest.raises(ValidationError):
AgentStrategyProviderEntity(identity="invalid")
# =========================================================
# AgentStrategyEntity Tests
# =========================================================
class TestAgentStrategyEntity:
def test_parameters_default_empty(self, mock_identity) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
)
assert entity.parameters == []
def test_parameters_none_converted_to_empty(self, mock_identity) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters=None,
)
assert entity.parameters == []
def test_parameters_preserved(self, mock_identity) -> None:
param = AgentStrategyParameter.model_construct(
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
name="test",
label="label",
)
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters=[param],
)
assert entity.parameters == [param]
def test_invalid_parameters_type_raises(self, mock_identity) -> None:
with pytest.raises(ValidationError):
AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters="invalid",
)
@pytest.mark.parametrize(
"features",
[
None,
[],
[AgentFeature.HISTORY_MESSAGES],
],
)
def test_features_valid(self, mock_identity, features) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
features=features,
)
assert entity.features == features
def test_invalid_features_type_raises(self, mock_identity) -> None:
with pytest.raises(ValidationError):
AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
features="invalid",
)
def test_output_schema_and_meta_version(self, mock_identity) -> None:
entity = AgentStrategyEntity(
identity=mock_identity,
description=I18nObject(en_US="test"),
output_schema={"type": "object"},
meta_version="v1",
)
assert entity.output_schema == {"type": "object"}
assert entity.meta_version == "v1"
def test_missing_required_fields_raise(self, mock_identity) -> None:
with pytest.raises(ValidationError):
AgentStrategyEntity(identity=mock_identity)
# =========================================================
# AgentProviderEntityWithPlugin Tests
# =========================================================
class TestAgentProviderEntityWithPlugin:
def test_default_strategies_empty(self, mock_provider_identity) -> None:
entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity)
assert entity.strategies == []
def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None:
strategy = AgentStrategyEntity.model_construct(
identity=mock_identity,
description=I18nObject(en_US="test"),
parameters=[],
)
entity = AgentProviderEntityWithPlugin(
identity=mock_provider_identity,
strategies=[strategy],
)
assert entity.strategies == [strategy]
def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None:
with pytest.raises(ValidationError):
AgentProviderEntityWithPlugin(
identity=mock_provider_identity,
strategies="invalid",
)
# =========================================================
# Inheritance Smoke Tests
# =========================================================
class TestInheritanceBehavior:
def test_agent_strategy_identity_inherits(self) -> None:
assert issubclass(AgentStrategyIdentity, ToolIdentity)
def test_agent_strategy_provider_identity_inherits(self) -> None:
assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)
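The `mocker.patch(...)` calls above replace a module-level helper and then assert on the recorded call. A self-contained stdlib sketch of the same pattern with `unittest.mock.patch.object`, using a hypothetical `helpers` namespace in place of `core.agent.plugin_entities`:

```python
from types import SimpleNamespace
from unittest import mock

# Hypothetical module-level helper, standing in for
# core.agent.plugin_entities.init_frontend_parameter.
helpers = SimpleNamespace(
    init_frontend_parameter=lambda param, type_, value: value
)


def init_frontend_parameter(param, value):
    # Delegates to the module-level helper, like the method under test.
    return helpers.init_frontend_parameter(param, param.type, value)


with mock.patch.object(
    helpers, "init_frontend_parameter", return_value="frontend"
) as mock_func:
    param = SimpleNamespace(type="string")
    result = init_frontend_parameter(param, "value")

    # Same assertions as the tests above: the patched helper was
    # called exactly once with the expected arguments.
    mock_func.assert_called_once_with(param, "string", "value")
    assert result == "frontend"
```

Patching the helper at its module of use (rather than its module of definition) is what makes the `"core.agent.plugin_entities.init_frontend_parameter"` target string work in the original tests.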


@@ -1,75 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from models.model import AppMode
class TestAdvancedChatAppConfigManager:
def test_get_app_config(self):
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value)
workflow = SimpleNamespace(id="wf-1", features_dict={})
with (
patch(
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
return_value=None,
),
patch(
"core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert",
return_value=[],
),
):
app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow)
assert app_config.workflow_id == "wf-1"
assert app_config.app_mode == AppMode.ADVANCED_CHAT
def test_config_validate_filters_keys(self):
def _add_key(key, value):
def _inner(*args, **kwargs):
config = kwargs.get("config") if kwargs else args[-1]
config = {**config, key: value}
return config, [key]
return _inner
with (
patch(
"core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
side_effect=_add_key("file_upload", 1),
),
patch(
"core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
side_effect=_add_key("opening_statement", 2),
),
patch(
"core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
side_effect=_add_key("suggested_questions_after_answer", 3),
),
patch(
"core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
side_effect=_add_key("speech_to_text", 4),
),
patch(
"core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
side_effect=_add_key("text_to_speech", 5),
),
patch(
"core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
side_effect=_add_key("retriever_resource", 6),
),
patch(
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
side_effect=_add_key("sensitive_word_avoidance", 7),
),
):
filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={})
assert filtered["file_upload"] == 1
assert filtered["opening_statement"] == 2
assert filtered["suggested_questions_after_answer"] == 3
assert filtered["speech_to_text"] == 4
assert filtered["text_to_speech"] == 5
assert filtered["retriever_resource"] == 6
assert filtered["sensitive_word_avoidance"] == 7
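The `_add_key` helper above stubs each `validate_and_set_defaults` so it adds one key and reports it as "related"; the manager then keeps only related keys. A minimal sketch of that chained-validator-and-filter flow, with hypothetical names (`config_validate`, `validators`) assumed for illustration:

```python
def _add_key(key, value):
    """Build a fake validator that adds one key and reports it as related."""

    def _inner(config):
        return {**config, key: value}, [key]

    return _inner


validators = [
    _add_key("file_upload", 1),
    _add_key("opening_statement", 2),
]


def config_validate(config):
    # Each validator returns (updated_config, related_keys);
    # the final config is filtered down to the related keys only.
    related_keys = []
    for validate in validators:
        config, keys = validate(config)
        related_keys.extend(keys)
    return {k: v for k, v in config.items() if k in related_keys}


filtered = config_validate({"extra": "value"})
assert filtered == {"file_upload": 1, "opening_statement": 2}
```

Note that the unrelated `"extra"` key is dropped, which is exactly what `test_config_validate_filters_keys` verifies against the real manager.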


@@ -1,96 +0,0 @@
from collections.abc import Generator
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
)
from dify_graph.enums import WorkflowNodeExecutionStatus
class TestAdvancedChatGenerateResponseConverter:
def test_blocking_simple_response_metadata(self):
data = ChatbotAppBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
answer="hi",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
)
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "usage" not in response["metadata"]
def test_stream_simple_response_includes_node_events(self):
node_start = NodeStartStreamResponse(
task_id="t1",
workflow_run_id="r1",
data=NodeStartStreamResponse.Data(
id="e1",
node_id="n1",
node_type="answer",
title="Answer",
index=1,
created_at=1,
),
)
node_finish = NodeFinishStreamResponse(
task_id="t1",
workflow_run_id="r1",
data=NodeFinishStreamResponse.Data(
id="e1",
node_id="n1",
node_type="answer",
title="Answer",
index=1,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
elapsed_time=0.1,
created_at=1,
finished_at=2,
),
)
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=PingStreamResponse(task_id="t1"),
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=node_start,
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=node_finish,
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
)
converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
assert converted[0] == "ping"
assert converted[1]["event"] == "node_started"
assert converted[2]["event"] == "node_finished"
assert converted[3]["event"] == "error"
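The converter test above drives a generator of wrapped stream responses through `convert_stream_simple_response` and checks the simplified output. A stdlib sketch of that event-mapping shape, with a hypothetical `convert_stream` (not the real converter API) that passes pings through as the bare string `"ping"` and flattens other events to dicts:

```python
from collections.abc import Generator


def convert_stream(events) -> Generator[object, None, None]:
    # Ping events collapse to a bare keep-alive marker;
    # everything else becomes a simple dict keyed by event name.
    for ev in events:
        if ev["event"] == "ping":
            yield "ping"
        else:
            yield {"event": ev["event"], "data": ev.get("data")}


converted = list(
    convert_stream(
        [
            {"event": "ping"},
            {"event": "node_started", "data": {"node_id": "n1"}},
            {"event": "node_finished", "data": {"node_id": "n1"}},
        ]
    )
)
assert converted[0] == "ping"
assert converted[1]["event"] == "node_started"
assert converted[2]["event"] == "node_finished"
```

Keeping the converter lazy (a generator in, a generator out) is what lets the real pipeline stream responses to the client without buffering the whole run.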


@@ -1,600 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from dify_graph.enums import NodeType
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from models.enums import MessageStatus
from models.model import AppMode, EndUser
def _make_pipeline():
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.ADVANCED_CHAT,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
query="hello",
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
extras={},
trace_manager=None,
workflow_run_id="run-id",
)
message = SimpleNamespace(
id="message-id",
query="hello",
created_at=datetime.utcnow(),
status=MessageStatus.NORMAL,
answer="",
)
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
conversation=conversation,
message=message,
user=user,
stream=False,
dialogue_count=1,
draft_var_saver_factory=lambda **kwargs: None,
)
return pipeline
class TestAdvancedChatGenerateTaskPipeline:
def test_ensure_workflow_initialized_raises(self):
pipeline = _make_pipeline()
with pytest.raises(ValueError, match="workflow run not initialized"):
pipeline._ensure_workflow_initialized()
def test_to_blocking_response_returns_message_end(self):
pipeline = _make_pipeline()
pipeline._task_state.answer = "done"
def _gen():
yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"})
response = pipeline._to_blocking_response(_gen())
assert response.data.answer == "done"
assert response.data.metadata == {"k": "v"}
def test_handle_text_chunk_event_updates_state(self):
pipeline = _make_pipeline()
pipeline._message_cycle_manager = SimpleNamespace(
message_to_stream_response=lambda **kwargs: MessageEndStreamResponse(
task_id="task", id="message-id", metadata={}
)
)
event = SimpleNamespace(text="hi", from_variable_selector=None)
responses = list(pipeline._handle_text_chunk_event(event))
assert pipeline._task_state.answer == "hi"
assert responses
def test_listen_audio_msg_returns_audio_stream(self):
pipeline = _make_pipeline()
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
assert isinstance(response, MessageAudioStreamResponse)
def test_handle_ping_event(self):
pipeline = _make_pipeline()
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
assert isinstance(responses[0], PingStreamResponse)
def test_handle_error_event(self):
pipeline = _make_pipeline()
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
@contextmanager
def _fake_session():
yield SimpleNamespace()
pipeline._database_session = _fake_session
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
assert isinstance(responses[0], ValueError)
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
@contextmanager
def _fake_session():
yield SimpleNamespace()
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
assert pipeline._workflow_run_id == "run-id"
assert responses == ["started"]
def test_message_end_to_stream_response_strips_annotation_reply(self):
pipeline = _make_pipeline()
pipeline._task_state.metadata.annotation_reply = AnnotationReply(
id="ann",
account=AnnotationReplyAccount(id="acc", name="acc"),
)
response = pipeline._message_end_to_stream_response()
assert "annotation_reply" not in response.metadata
def test_handle_output_moderation_chunk_publishes_stop(self):
pipeline = _make_pipeline()
events: list[object] = []
class _Moderation:
def should_direct_output(self):
return True
def get_final_output(self):
return "final"
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
pipeline._base_task_pipeline.queue_manager = SimpleNamespace(
publish=lambda event, pub_from: events.append(event)
)
result = pipeline._handle_output_moderation_chunk("ignored")
assert result is True
assert pipeline._task_state.answer == "final"
assert any(isinstance(event, QueueTextChunkEvent) for event in events)
assert any(isinstance(event, QueueStopEvent) for event in events)
def test_handle_node_succeeded_event_records_files(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [
{"type": "file", "transfer_method": "local"}
]
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
pipeline._save_output_for_event = lambda event, node_execution_id: None
event = SimpleNamespace(
node_type=NodeType.ANSWER,
outputs={"k": "v"},
node_execution_id="exec",
node_id="node",
)
responses = list(pipeline._handle_node_succeeded_event(event))
assert responses == ["done"]
assert pipeline._recorded_files
def test_iteration_and_loop_handlers(self):
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
"iter_start"
)
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
"iter_done"
)
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
iter_start = QueueIterationStartEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
iter_next = QueueIterationNextEvent(
index=1,
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
node_run_index=1,
)
iter_done = QueueIterationCompletedEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
loop_start = QueueLoopStartEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
loop_next = QueueLoopNextEvent(
index=1,
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
node_run_index=1,
)
loop_done = QueueLoopCompletedEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"]
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"]
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"]
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"]
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
def test_workflow_finish_handlers(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"]
pipeline._persist_human_input_extra_content = lambda **kwargs: None
pipeline._save_message = lambda **kwargs: None
pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id")
@contextmanager
def _fake_session():
yield SimpleNamespace(scalar=lambda *args, **kwargs: None)
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={})))
assert len(succeeded_responses) == 2
assert isinstance(succeeded_responses[0], MessageEndStreamResponse)
assert succeeded_responses[1] == "finish"
partial_success_responses = list(
pipeline._handle_workflow_partial_success_event(
QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
)
)
assert len(partial_success_responses) == 2
assert isinstance(partial_success_responses[0], MessageEndStreamResponse)
assert partial_success_responses[1] == "finish"
assert (
list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0]
== "finish"
)
assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [
"pause"
]
def test_node_failure_handlers(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish"
pipeline._save_output_for_event = lambda event, node_execution_id: None
failed_event = QueueNodeFailedEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
start_at=datetime.utcnow(),
inputs={},
outputs={},
process_data={},
error="err",
)
exc_event = QueueNodeExceptionEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
start_at=datetime.utcnow(),
inputs={},
outputs={},
process_data={},
error="err",
)
assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"]
assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"]
def test_handle_text_chunk_event_tracks_streaming_metrics(self):
pipeline = _make_pipeline()
published: list[object] = []
class _Publisher:
def publish(self, message):
published.append(message)
pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")
event = SimpleNamespace(text="hi", from_variable_selector=["a"])
queue_message = SimpleNamespace(event=event)
responses = list(
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
)
assert responses == ["chunk"]
assert pipeline._task_state.is_streaming_response is True
assert pipeline._task_state.first_token_time is not None
assert pipeline._task_state.last_token_time is not None
assert pipeline._task_state.answer == "hi"
assert published == [queue_message]
def test_handle_output_moderation_chunk_appends_token(self):
pipeline = _make_pipeline()
seen: list[str] = []
class _Moderation:
def should_direct_output(self):
return False
def append_new_token(self, text):
seen.append(text)
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
result = pipeline._handle_output_moderation_chunk("token")
assert result is False
assert seen == ["token"]
def test_handle_retriever_and_annotation_events(self):
pipeline = _make_pipeline()
calls = {"retriever": 0, "annotation": 0}
def _hit_retriever(event):
calls["retriever"] += 1
def _hit_annotation(event):
calls["annotation"] += 1
pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever
pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation
retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[])
annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann")
assert list(pipeline._handle_retriever_resources_event(retriever_event)) == []
assert list(pipeline._handle_annotation_reply_event(annotation_event)) == []
assert calls == {"retriever": 1, "annotation": 1}
def test_handle_message_replace_event(self):
pipeline = _make_pipeline()
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
event = QueueMessageReplaceEvent(
text="new",
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
assert list(pipeline._handle_message_replace_event(event)) == ["replace"]
def test_handle_human_input_events(self):
pipeline = _make_pipeline()
persisted: list[str] = []
pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved")
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
filled_event = QueueHumanInputFormFilledEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="title",
rendered_content="content",
action_id="action",
action_text="action",
)
timeout_event = QueueHumanInputFormTimeoutEvent(
node_id="node",
node_type=NodeType.LLM,
node_title="title",
expiration_time=datetime.utcnow(),
)
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
assert persisted == ["saved"]
def test_save_message_strips_markdown_and_sets_usage(self):
pipeline = _make_pipeline()
pipeline._recorded_files = [
{
"type": "image",
"transfer_method": "remote",
"remote_url": "http://example.com/file.png",
"related_id": "file-id",
}
]
pipeline._task_state.answer = "![img](url) hello"
pipeline._task_state.is_streaming_response = True
pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1
pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2
message = SimpleNamespace(
id="message-id",
status=MessageStatus.PAUSED,
answer="",
updated_at=None,
provider_response_latency=None,
message_tokens=None,
message_unit_price=None,
message_price_unit=None,
answer_tokens=None,
answer_unit_price=None,
answer_price_unit=None,
total_price=None,
currency=None,
message_metadata=None,
invoke_from=InvokeFrom.WEB_APP,
from_account_id=None,
from_end_user_id="end-user",
)
class _Session:
def scalar(self, *args, **kwargs):
return message
def add_all(self, items):
self.items = items
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state)
assert message.status == MessageStatus.NORMAL
assert message.answer == "hello"
assert message.message_metadata
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._message_end_to_stream_response = lambda: "end"
saved: list[str] = []
def _save_message(**kwargs):
saved.append("saved")
pipeline._save_message = _save_message
@contextmanager
def _fake_session():
yield SimpleNamespace()
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)))
assert responses == ["end"]
assert saved == ["saved"]
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
pipeline._message_end_to_stream_response = lambda: "end"
saved: list[str] = []
def _save_message(**kwargs):
saved.append("saved")
pipeline._save_message = _save_message
@contextmanager
def _fake_session():
yield SimpleNamespace()
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent()))
assert responses == ["replace", "end"]
assert saved == ["saved"]
def test_dispatch_event_handles_node_exception(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
pipeline._save_output_for_event = lambda *args, **kwargs: None
event = QueueNodeExceptionEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
start_at=datetime.utcnow(),
inputs={},
outputs={},
process_data={},
error="err",
)
assert list(pipeline._dispatch_event(event)) == ["failed"]
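Many of the pipeline tests above swap `_database_session` for a fake built with `@contextmanager`, so code written as `with self._database_session() as session:` never touches a real database. A self-contained sketch of that substitution, using a hypothetical `Pipeline.save` method for illustration:

```python
from contextlib import contextmanager
from types import SimpleNamespace


class Pipeline:
    def _database_session(self):
        raise RuntimeError("real database not available in tests")

    def save(self):
        # Production code path: open a session, record something.
        with self._database_session() as session:
            session.add("record")
            return "saved"


records = []


@contextmanager
def _fake_session():
    # Yields a stand-in session whose add() just collects items.
    yield SimpleNamespace(add=records.append)


pipeline = Pipeline()
# Instance-attribute assignment shadows the method, exactly like
# monkeypatch.setattr(pipeline, "_database_session", _fake_session).
pipeline._database_session = _fake_session

assert pipeline.save() == "saved"
assert records == ["record"]
```

Because the fake is assigned on the instance as a plain function, it is not bound to `self`, which is why the zero-argument `_fake_session()` signature works in the tests above.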


@@ -1,302 +0,0 @@
import uuid
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.agent_chat.app_config_manager import (
AgentChatAppConfigManager,
)
from core.entities.agent_entities import PlanningStrategy
class TestAgentChatAppConfigManagerGetAppConfig:
def test_get_app_config_override_config(self, mocker):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"ignored": True}
override_config = {"model": {"provider": "p"}}
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
return_value=("variables", "external"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
)
result = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=None,
override_config_dict=override_config,
)
assert result.app_model_config_dict == override_config
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
assert result.variables == "variables"
assert result.external_data_variables == "external"
def test_get_app_config_conversation_specific(self, mocker):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
conversation = mocker.MagicMock()
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
return_value=("variables", "external"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
)
result = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=None,
)
assert result.app_model_config_dict == app_model_config.to_dict.return_value
assert result.app_model_config_from.value == "conversation-specific-config"
def test_get_app_config_latest_config(self, mocker):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
return_value=("variables", "external"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
)
result = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=None,
override_config_dict=None,
)
assert result.app_model_config_from.value == "app-latest-config"
class TestAgentChatAppConfigManagerConfigValidate:
def test_config_validate_filters_related_keys(self, mocker):
config = {
"model": {},
"user_input_form": {},
"file_upload": {},
"prompt_template": {},
"agent_mode": {},
"opening_statement": {},
"suggested_questions_after_answer": {},
"speech_to_text": {},
"text_to_speech": {},
"retriever_resource": {},
"dataset": {},
"moderation": {},
"extra": "value",
}
def return_with_key(key):
return config, [key]
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
side_effect=lambda tenant_id, cfg: return_with_key("model"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
side_effect=lambda cfg: return_with_key("file_upload"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
side_effect=lambda app_mode, cfg: return_with_key("prompt_template"),
)
mocker.patch.object(
AgentChatAppConfigManager,
"validate_agent_mode_and_set_defaults",
side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
side_effect=lambda cfg: return_with_key("opening_statement"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
side_effect=lambda cfg: return_with_key("speech_to_text"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
side_effect=lambda cfg: return_with_key("text_to_speech"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
side_effect=lambda cfg: return_with_key("retriever_resource"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"),
)
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
side_effect=lambda tenant_id, cfg: return_with_key("moderation"),
)
filtered = AgentChatAppConfigManager.config_validate("tenant", config)
assert set(filtered.keys()) == {
"model",
"user_input_form",
"file_upload",
"prompt_template",
"agent_mode",
"opening_statement",
"suggested_questions_after_answer",
"speech_to_text",
"text_to_speech",
"retriever_resource",
"dataset",
"moderation",
}
assert "extra" not in filtered
class TestValidateAgentModeAndSetDefaults:
def test_defaults_when_missing(self):
config = {}
updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
assert "agent_mode" in updated
assert updated["agent_mode"]["enabled"] is False
assert updated["agent_mode"]["tools"] == []
assert keys == ["agent_mode"]
@pytest.mark.parametrize(
"agent_mode",
["invalid", 123],
)
def test_agent_mode_type_validation(self, agent_mode):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode})
def test_agent_mode_empty_list_defaults(self):
config = {"agent_mode": []}
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
assert updated["agent_mode"]["enabled"] is False
assert updated["agent_mode"]["tools"] == []
def test_enabled_must_be_bool(self):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}})
def test_strategy_must_be_valid(self):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}}
)
def test_tools_must_be_list(self):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}}
)
def test_old_tool_dataset_requires_id(self):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}}
)
def test_old_tool_dataset_id_must_be_uuid(self):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant",
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
)
def test_old_tool_dataset_id_not_exists(self, mocker):
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
return_value=False,
)
dataset_id = str(uuid.uuid4())
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant",
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}},
)
def test_old_tool_enabled_must_be_bool(self):
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant",
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}},
)
@pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"])
def test_new_style_tool_requires_fields(self, missing_key):
        tool = {
            "enabled": True,
            "provider_type": "type",
            "provider_id": "id",
            "tool_name": "tool",
            "tool_parameters": {},
        }
        # Remove exactly one required field so each case isolates that key
        tool.pop(missing_key)
with pytest.raises(ValueError):
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
"tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
)
def test_valid_old_and_new_style_tools(self, mocker):
mocker.patch(
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
return_value=True,
)
dataset_id = str(uuid.uuid4())
config = {
"agent_mode": {
"enabled": True,
"strategy": PlanningStrategy.ROUTER.value,
"tools": [
{"dataset": {"id": dataset_id}},
{
"provider_type": "builtin",
"provider_id": "p1",
"tool_name": "tool",
"tool_parameters": {},
},
],
}
}
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False
assert updated["agent_mode"]["tools"][1]["enabled"] is False


@@ -1,296 +0,0 @@
import contextlib
import pytest
from pydantic import ValidationError
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
class DummyAccount:
def __init__(self, user_id):
self.id = user_id
self.session_id = f"session-{user_id}"
@pytest.fixture
def generator(mocker):
gen = AgentChatAppGenerator()
mocker.patch(
"core.app.apps.agent_chat.app_generator.current_app",
new=mocker.MagicMock(_get_current_object=mocker.MagicMock()),
)
mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx")
return gen
class TestAgentChatAppGeneratorGenerate:
def test_generate_rejects_blocking_mode(self, generator, mocker):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)
def test_generate_requires_query(self, generator, mocker):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())
def test_generate_rejects_non_string_query(self, generator, mocker):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
generator.generate(
app_model=app_model,
user=user,
args={"query": 123, "inputs": {}},
invoke_from=mocker.MagicMock(),
)
def test_generate_override_requires_debugger(self, generator, mocker):
app_model = mocker.MagicMock()
user = DummyAccount("user")
with pytest.raises(ValueError):
generator.generate(
app_model=app_model,
user=user,
args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}},
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_success_with_debugger_override(self, generator, mocker):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
user = DummyAccount("user")
invoke_from = InvokeFrom.DEBUGGER
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
generator._init_generate_records = mocker.MagicMock(
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
)
generator._handle_response = mocker.MagicMock(return_value="response")
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate",
return_value={"validated": True},
)
app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[])
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
return_value=app_config,
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
return_value=mocker.MagicMock(),
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
return_value=mocker.MagicMock(),
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
return_value=["file-obj"],
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.ConversationService.get_conversation",
return_value=mocker.MagicMock(id="conv"),
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
return_value=mocker.MagicMock(),
)
queue_manager = mocker.MagicMock()
mocker.patch(
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
return_value=queue_manager,
)
thread_obj = mocker.MagicMock()
mocker.patch(
"core.app.apps.agent_chat.app_generator.threading.Thread",
return_value=thread_obj,
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
return_value={"result": "ok"},
)
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from)
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
return_value=app_entity,
)
args = {
"query": "hello",
"inputs": {"name": "world"},
"conversation_id": "conv",
"model_config": {"model": {"provider": "p"}},
"files": [{"id": "f1"}],
}
result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True)
assert result == {"result": "ok"}
thread_obj.start.assert_called_once()
def test_generate_without_file_config(self, generator, mocker):
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
app_model_config = mocker.MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
user = DummyAccount("user")
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
generator._init_generate_records = mocker.MagicMock(
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
)
generator._handle_response = mocker.MagicMock(return_value="response")
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]),
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
return_value=mocker.MagicMock(),
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
return_value=None,
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
return_value=["file-obj"],
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
return_value=mocker.MagicMock(),
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
return_value=mocker.MagicMock(),
)
thread_obj = mocker.MagicMock()
mocker.patch(
"core.app.apps.agent_chat.app_generator.threading.Thread",
return_value=thread_obj,
)
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
return_value={"result": "ok"},
)
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP)
mocker.patch(
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
return_value=app_entity,
)
args = {"query": "hello", "inputs": {"name": "world"}}
result = generator.generate(
app_model=app_model,
user=user,
args=args,
invoke_from=InvokeFrom.WEB_APP,
streaming=True,
)
assert result == {"result": "ok"}
class TestAgentChatAppGeneratorWorker:
@pytest.fixture(autouse=True)
def patch_context(self, mocker):
@contextlib.contextmanager
def ctx_manager(*args, **kwargs):
yield
mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)
def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
queue_manager = mocker.MagicMock()
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
runner = mocker.MagicMock()
runner.run.side_effect = GenerateTaskStoppedError()
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
generator._generate_worker(
flask_app=mocker.MagicMock(),
context=mocker.MagicMock(),
application_generate_entity=mocker.MagicMock(),
queue_manager=queue_manager,
conversation_id="conv",
message_id="msg",
)
queue_manager.publish_error.assert_not_called()
@pytest.mark.parametrize(
"error",
[
InvokeAuthorizationError("bad"),
ValidationError.from_exception_data("TestModel", []),
ValueError("bad"),
Exception("bad"),
],
)
def test_generate_worker_publishes_errors(self, generator, mocker, error):
queue_manager = mocker.MagicMock()
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
runner = mocker.MagicMock()
runner.run.side_effect = error
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
generator._generate_worker(
flask_app=mocker.MagicMock(),
context=mocker.MagicMock(),
application_generate_entity=mocker.MagicMock(),
queue_manager=queue_manager,
conversation_id="conv",
message_id="msg",
)
assert queue_manager.publish_error.called
def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
queue_manager = mocker.MagicMock()
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
runner = mocker.MagicMock()
runner.run.side_effect = ValueError("bad")
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True))
logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger")
generator._generate_worker(
flask_app=mocker.MagicMock(),
context=mocker.MagicMock(),
application_generate_entity=mocker.MagicMock(),
queue_manager=queue_manager,
conversation_id="conv",
message_id="msg",
)
logger.exception.assert_called_once()


@@ -1,413 +0,0 @@
import pytest
from core.agent.entities import AgentEntity
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.moderation.base import ModerationError
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
@pytest.fixture
def runner():
return AgentChatAppRunner()
class TestAgentChatAppRunnerRun:
def test_run_app_not_found(self, runner, mocker):
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
def test_run_moderation_error_direct_output(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(),
conversation_id=None,
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
mocker.patch.object(runner, "direct_output")
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
runner.direct_output.assert_called_once()
def test_run_annotation_reply_short_circuits(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(),
conversation_id=None,
user_id="user",
invoke_from=mocker.MagicMock(),
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
annotation = mocker.MagicMock(id="anno", content="answer")
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation)
mocker.patch.object(runner, "direct_output")
queue_manager = mocker.MagicMock()
runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock())
queue_manager.publish.assert_called_once()
runner.direct_output.assert_called_once()
def test_run_hosting_moderation_short_circuits(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock()
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(),
conversation_id=None,
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=True)
        runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
        runner.check_hosting_moderation.assert_called_once()
def test_run_model_schema_missing(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
llm_instance = mocker.MagicMock()
llm_instance.model_type_instance.get_model_schema.return_value = None
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
@pytest.mark.parametrize(
("mode", "expected_runner"),
[
(LLMMode.CHAT, "CotChatAgentRunner"),
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
],
)
def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
model_schema = mocker.MagicMock()
model_schema.features = []
model_schema.model_properties = {ModelPropertyKey.MODE: mode}
llm_instance = mocker.MagicMock()
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
runner_cls = mocker.MagicMock()
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
runner_instance = mocker.MagicMock()
runner_cls.return_value = runner_instance
runner_instance.run.return_value = []
mocker.patch.object(runner, "_handle_invoke_result")
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
runner_instance.run.assert_called_once()
runner._handle_invoke_result.assert_called_once()
def test_run_invalid_llm_mode_raises(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
model_schema = mocker.MagicMock()
model_schema.features = []
model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}
llm_instance = mocker.MagicMock()
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
model_schema = mocker.MagicMock()
model_schema.features = [ModelFeature.TOOL_CALL]
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
llm_instance = mocker.MagicMock()
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
runner_cls = mocker.MagicMock()
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
runner_instance = mocker.MagicMock()
runner_cls.return_value = runner_instance
runner_instance.run.return_value = []
mocker.patch.object(runner, "_handle_invoke_result")
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
runner_instance.run.assert_called_once()
def test_run_conversation_not_found(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, None],
)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
def test_run_message_not_found(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
def test_run_invalid_agent_strategy_raises(self, runner, mocker):
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
generate_entity = mocker.MagicMock(
app_config=app_config,
inputs={},
query="q",
files=[],
stream=True,
model_conf=mocker.MagicMock(
provider_model_bundle=mocker.MagicMock(),
model="m",
provider="p",
credentials={"k": "v"},
),
conversation_id="conv",
invoke_from=mocker.MagicMock(),
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
model_schema = mocker.MagicMock()
model_schema.features = []
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
llm_instance = mocker.MagicMock()
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)


@@ -1,162 +0,0 @@
from collections.abc import Generator
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageStreamResponse,
PingStreamResponse,
)
class TestAgentChatAppGenerateResponseConverterBlocking:
def test_convert_blocking_full_response(self):
blocking = ChatbotAppBlockingResponse(
task_id="task",
data=ChatbotAppBlockingResponse.Data(
id="id",
mode="agent-chat",
conversation_id="conv",
message_id="msg",
answer="answer",
metadata={"a": 1},
created_at=123,
),
)
result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
assert result["event"] == "message"
assert result["answer"] == "answer"
assert result["metadata"] == {"a": 1}
def test_convert_blocking_simple_response_with_dict_metadata(self):
blocking = ChatbotAppBlockingResponse(
task_id="task",
data=ChatbotAppBlockingResponse.Data(
id="id",
mode="agent-chat",
conversation_id="conv",
message_id="msg",
answer="answer",
metadata={
"retriever_resources": [
{
"segment_id": "s1",
"position": 1,
"document_name": "doc",
"score": 0.9,
"content": "content",
}
],
"annotation_reply": {"id": "a"},
"usage": {"prompt_tokens": 1},
},
created_at=123,
),
)
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "annotation_reply" not in result["metadata"]
assert "usage" not in result["metadata"]
def test_convert_blocking_simple_response_with_non_dict_metadata(self):
blocking = ChatbotAppBlockingResponse.model_construct(
task_id="task",
data=ChatbotAppBlockingResponse.Data.model_construct(
id="id",
mode="agent-chat",
conversation_id="conv",
message_id="msg",
answer="answer",
metadata="bad",
created_at=123,
),
)
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert result["metadata"] == {}
class TestAgentChatAppGenerateResponseConverterStream:
def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]:
def _gen():
yield ChatbotAppStreamResponse(
conversation_id="conv",
message_id="msg",
created_at=1,
stream_response=PingStreamResponse(task_id="t"),
)
yield ChatbotAppStreamResponse(
conversation_id="conv",
message_id="msg",
created_at=2,
stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"),
)
yield ChatbotAppStreamResponse(
conversation_id="conv",
message_id="msg",
created_at=3,
stream_response=MessageEndStreamResponse(
task_id="t",
id="m1",
metadata={
"retriever_resources": [
{
"segment_id": "s1",
"position": 1,
"document_name": "doc",
"score": 0.9,
"content": "content",
"summary": "summary",
"extra": "ignored",
}
],
"annotation_reply": {"id": "a"},
"usage": {"prompt_tokens": 1},
},
),
)
yield ChatbotAppStreamResponse(
conversation_id="conv",
message_id="msg",
created_at=4,
stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")),
)
return _gen()
def test_convert_stream_full_response(self):
items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream()))
assert items[0] == "ping"
assert items[1]["event"] == "message"
assert "answer" in items[1]
assert items[2]["event"] == "message_end"
assert items[3]["event"] == "error"
def test_convert_stream_simple_response(self):
items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream()))
assert items[0] == "ping"
# The message event keeps its event type and answer content.
assert items[1]["event"] == "message"
assert "hi" in items[1]["answer"]
assert items[2]["event"] == "message_end"
assert "metadata" in items[2]
metadata = items[2]["metadata"]
assert "annotation_reply" not in metadata
assert "usage" not in metadata
assert metadata["retriever_resources"] == [
{
"segment_id": "s1",
"position": 1,
"document_name": "doc",
"score": 0.9,
"content": "content",
"summary": "summary",
}
]
assert items[3]["event"] == "error"
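
The simple-response assertions above encode a metadata-scrubbing step: internal keys (`annotation_reply`, `usage`) are dropped, each retriever resource is trimmed to a fixed field whitelist, and non-dict metadata collapses to `{}`. A minimal, self-contained sketch of that behavior (field names come from the assertions; the helper name is hypothetical):

```python
# Keys the simple response strips entirely (internal bookkeeping).
INTERNAL_KEYS = {"annotation_reply", "usage"}
# Fields a retriever resource may expose to API consumers.
ALLOWED_RESOURCE_FIELDS = {"segment_id", "position", "document_name", "score", "content", "summary"}


def simplify_metadata(metadata):
    """Drop internal keys and trim retriever resources to the whitelist."""
    if not isinstance(metadata, dict):
        return {}  # non-dict metadata collapses to an empty dict
    simplified = {k: v for k, v in metadata.items() if k not in INTERNAL_KEYS}
    if "retriever_resources" in simplified:
        simplified["retriever_resources"] = [
            {k: v for k, v in resource.items() if k in ALLOWED_RESOURCE_FIELDS}
            for resource in simplified["retriever_resources"]
        ]
    return simplified
```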


@@ -1,113 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from models.model import AppMode
class TestChatAppConfigManager:
def test_get_app_config_uses_override_dict(self):
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"})
override = {"model": "override"}
model_entity = ModelConfigEntity(provider="p", model="m")
prompt_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="hi",
)
with (
patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity),
patch(
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity
),
patch(
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
return_value=None,
),
patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None),
patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])),
):
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=None,
override_config_dict=override,
)
assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
assert app_config.app_model_config_dict == override
assert app_config.app_mode == AppMode.CHAT
def test_config_validate_filters_related_keys(self):
config = {"extra": 1}
def _add_key(key, value):
def _inner(*args, **kwargs):
config = args[-1]
config = {**config, key: value}
return config, [key]
return _inner
with (
patch(
"core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
side_effect=_add_key("model", 1),
),
patch(
"core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
side_effect=_add_key("inputs", 2),
),
patch(
"core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
side_effect=_add_key("file_upload", 3),
),
patch(
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
side_effect=_add_key("prompt", 4),
),
patch(
"core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
side_effect=_add_key("dataset", 5),
),
patch(
"core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
side_effect=_add_key("opening_statement", 6),
),
patch(
"core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
side_effect=_add_key("suggested_questions_after_answer", 7),
),
patch(
"core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
side_effect=_add_key("speech_to_text", 8),
),
patch(
"core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
side_effect=_add_key("text_to_speech", 9),
),
patch(
"core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
side_effect=_add_key("retriever_resource", 10),
),
patch(
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
side_effect=_add_key("sensitive_word_avoidance", 11),
),
):
filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config)
assert filtered["model"] == 1
assert filtered["inputs"] == 2
assert filtered["file_upload"] == 3
assert filtered["prompt"] == 4
assert filtered["dataset"] == 5
assert filtered["opening_statement"] == 6
assert filtered["suggested_questions_after_answer"] == 7
assert filtered["speech_to_text"] == 8
assert filtered["text_to_speech"] == 9
assert filtered["retriever_resource"] == 10
assert filtered["sensitive_word_avoidance"] == 11
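
`config_validate` chains per-feature validators, each returning the updated config plus the keys it owns; only the union of owned keys survives the final filter. A schematic version of that contract (the validator functions here are illustrative stand-ins, not the real config managers):

```python
def validate_config(config, validators):
    """Run each validator in order; keep only keys a validator claimed."""
    related_keys = []
    for validate in validators:
        config, keys = validate(config)
        related_keys.extend(keys)
    return {k: v for k, v in config.items() if k in set(related_keys)}


# Illustrative validators following the (config, related_keys) contract.
def validate_model(config):
    config.setdefault("model", {"provider": "p"})
    return config, ["model"]


def validate_prompt(config):
    config.setdefault("prompt", {"template": ""})
    return config, ["prompt"]
```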


@@ -1,280 +0,0 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.moderation.base import ModerationError
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from models.model import AppMode
class DummyGenerateEntity:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class DummyQueueManager:
def __init__(self, *args, **kwargs):
self.published = []
def publish_error(self, error, pub_from):
self.published.append((error, pub_from))
def publish(self, event, pub_from):
self.published.append((event, pub_from))
class TestChatAppGenerator:
def test_generate_requires_query(self):
generator = ChatAppGenerator()
with pytest.raises(ValueError):
generator.generate(
app_model=SimpleNamespace(),
user=SimpleNamespace(),
args={"inputs": {}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
def test_generate_rejects_non_string_query(self):
generator = ChatAppGenerator()
with pytest.raises(ValueError):
generator.generate(
app_model=SimpleNamespace(),
user=SimpleNamespace(),
args={"query": 1, "inputs": {}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
def test_generate_debugger_overrides_model_config(self):
generator = ChatAppGenerator()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
user = SimpleNamespace(id="user-1", session_id="session-1")
args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}
with (
patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None),
patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}),
patch(
"core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config",
return_value=SimpleNamespace(
variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT
),
),
patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()),
patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None),
patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]),
patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity),
patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()),
patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager),
patch(
"core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True}
),
patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})),
patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}),
patch.object(
ChatAppGenerator,
"_init_generate_records",
return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")),
),
patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}),
patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f),
patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread,
):
mock_thread.return_value.start.return_value = None
result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False)
assert result == {"ok": True}
def test_generate_rejects_model_config_override_for_non_debugger(self):
generator = ChatAppGenerator()
with pytest.raises(ValueError):
with (
patch.object(
ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})
),
):
generator.generate(
app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value),
user=SimpleNamespace(id="u1", session_id="s1"),
args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
def test_generate_worker_handles_exceptions(self):
generator = ChatAppGenerator()
queue_manager = DummyQueueManager()
entity = DummyGenerateEntity(task_id="t1", user_id="u1")
with (
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()),
patch("core.app.apps.chat.app_generator.db.session.close"),
):
generator._generate_worker(
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
application_generate_entity=entity,
queue_manager=queue_manager,
conversation_id="c1",
message_id="m1",
)
assert queue_manager.published
with (
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()),
patch("core.app.apps.chat.app_generator.db.session.close"),
):
generator._generate_worker(
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
application_generate_entity=entity,
queue_manager=queue_manager,
conversation_id="c1",
message_id="m1",
)
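
The two worker runs above pin down asymmetric error handling: an invoke error must reach the queue, while `GenerateTaskStoppedError` is swallowed silently. A compact sketch of that dispatch (the worker body and names are illustrative, not the real `_generate_worker`):

```python
class StoppedError(Exception):
    """Stands in for GenerateTaskStoppedError (user stopped the task)."""


def generate_worker(run, publish_error):
    """Run the app; publish real failures, swallow task-stopped signals."""
    try:
        run()
    except StoppedError:
        pass  # nothing to report: the user asked the task to stop
    except Exception as exc:  # includes authorization/invoke errors
        publish_error(exc)
```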
class TestChatAppRunner:
def test_run_raises_when_app_missing(self):
runner = ChatAppRunner()
app_config = SimpleNamespace(
app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[]
)
app_generate_entity = DummyGenerateEntity(
app_config=app_config,
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
inputs={},
query="hi",
files=[],
file_upload_config=None,
conversation_id=None,
stream=False,
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
with pytest.raises(ValueError):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
def test_run_moderation_error_direct_output(self):
runner = ChatAppRunner()
app_config = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
prompt_template=None,
external_data_variables=[],
dataset=None,
additional_features=None,
)
app_generate_entity = DummyGenerateEntity(
app_config=app_config,
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
inputs={},
query="hi",
files=[],
file_upload_config=None,
conversation_id=None,
stream=False,
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
patch.object(ChatAppRunner, "direct_output") as mock_direct,
):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
mock_direct.assert_called_once()
def test_run_annotation_reply_short_circuits(self):
runner = ChatAppRunner()
app_config = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
prompt_template=None,
external_data_variables=[],
dataset=None,
additional_features=None,
)
app_generate_entity = DummyGenerateEntity(
app_config=app_config,
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
inputs={},
query="hi",
files=[],
file_upload_config=None,
conversation_id=None,
stream=False,
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
annotation = SimpleNamespace(id="ann-1", content="answer")
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
patch.object(ChatAppRunner, "direct_output") as mock_direct,
):
queue_manager = DummyQueueManager()
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published)
mock_direct.assert_called_once()
def test_run_returns_when_hosting_moderation_blocks(self):
runner = ChatAppRunner()
app_config = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
prompt_template=None,
external_data_variables=[],
dataset=None,
additional_features=None,
)
app_generate_entity = DummyGenerateEntity(
app_config=app_config,
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
inputs={},
query="hi",
files=[],
file_upload_config=None,
conversation_id=None,
stream=False,
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
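
Taken together, the `ChatAppRunner` tests assert a fixed short-circuit order: input moderation, annotation reply, hosting moderation, and only then the model call. A schematic of that control flow (flags replace the real exceptions and lookups; names are illustrative):

```python
def run_pipeline(query, *, moderation_blocked, annotation, hosting_blocked):
    """Mirror the short-circuit order asserted above: moderation first,
    then annotation reply, then hosting moderation, then the model call."""
    if moderation_blocked:
        return "direct_output:moderation"
    if annotation is not None:
        return f"direct_output:annotation:{annotation}"
    if hosting_blocked:
        return "stopped"
    return "invoke_llm"
```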


@@ -1,65 +0,0 @@
from collections.abc import Generator
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageStreamResponse,
PingStreamResponse,
)
class TestChatAppGenerateResponseConverter:
def test_convert_blocking_simple_response_metadata(self):
data = ChatbotAppBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
answer="hi",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
)
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "usage" not in response["metadata"]
def test_convert_stream_responses(self):
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=PingStreamResponse(task_id="t1"),
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"),
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
)
yield ChatbotAppStreamResponse(
conversation_id="c1",
message_id="m1",
created_at=1,
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
)
full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream()))
assert full[0] == "ping"
assert full[1]["event"] == "message"
assert full[2]["event"] == "error"
simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
assert simple[0] == "ping"
assert simple[-1]["event"] == "message_end"


@@ -1,162 +0,0 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.app.apps.completion.app_runner as module
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.moderation.base import ModerationError
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
@pytest.fixture
def runner():
return CompletionAppRunner()
def _build_app_config(dataset=None, external_tools=None, additional_features=None):
app_config = MagicMock()
app_config.app_id = "app1"
app_config.tenant_id = "tenant"
app_config.prompt_template = MagicMock()
app_config.dataset = dataset
app_config.external_data_variables = external_tools or []
app_config.additional_features = additional_features
app_config.app_model_config_dict = {"file_upload": {"enabled": True}}
return app_config
def _build_generate_entity(app_config, file_upload_config=None):
model_conf = MagicMock(
provider_model_bundle="bundle",
model="model",
parameters={"max_tokens": 10},
stop=["stop"],
)
return SimpleNamespace(
app_config=app_config,
model_conf=model_conf,
inputs={"qvar": "query_from_input"},
query="original_query",
files=[],
file_upload_config=file_upload_config,
stream=True,
user_id="user",
invoke_from=MagicMock(),
)
class TestCompletionAppRunner:
def test_run_app_not_found(self, runner, mocker):
session = mocker.MagicMock()
session.scalar.return_value = None
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
with pytest.raises(ValueError):
runner.run(app_generate_entity, MagicMock(), MagicMock())
def test_run_moderation_error_outputs_direct(self, runner, mocker):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
runner.organize_prompt_messages = MagicMock(return_value=([], None))
runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked"))
runner.direct_output = MagicMock()
runner._handle_invoke_result = MagicMock()
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
runner.direct_output.assert_called_once()
runner._handle_invoke_result.assert_not_called()
def test_run_hosting_moderation_stops(self, runner, mocker):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
runner.organize_prompt_messages = MagicMock(return_value=([], None))
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.check_hosting_moderation = MagicMock(return_value=True)
runner._handle_invoke_result = MagicMock()
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
runner._handle_invoke_result.assert_not_called()
def test_run_dataset_and_external_tools_flow(self, runner, mocker):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
session.close = MagicMock()
mocker.patch.object(module.db, "session", session)
retrieve_config = MagicMock(query_variable="qvar")
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
additional_features = MagicMock(show_retrieve_source=True)
app_config = _build_app_config(
dataset=dataset_config,
external_tools=["tool"],
additional_features=additional_features,
)
file_upload_config = MagicMock()
file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH
app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config)
runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])])
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs)
runner.check_hosting_moderation = MagicMock(return_value=False)
runner.recalc_llm_max_tokens = MagicMock()
runner._handle_invoke_result = MagicMock()
dataset_retrieval = MagicMock()
dataset_retrieval.retrieve.return_value = ("ctx", ["file1"])
mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval)
model_instance = MagicMock()
model_instance.invoke_llm.return_value = "invoke_result"
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
dataset_retrieval.retrieve.assert_called_once()
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
runner._handle_invoke_result.assert_called_once()
def test_run_uses_low_image_detail_default(self, runner, mocker):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
runner.organize_prompt_messages = MagicMock(return_value=([], None))
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.check_hosting_moderation = MagicMock(return_value=True)
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
assert (
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
== ImagePromptMessageContent.DETAIL.LOW
)
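
`test_run_uses_low_image_detail_default` checks the fallback when no file-upload config is supplied. A tiny sketch of that resolution logic (the enum and helper are assumptions for illustration, not the real `ImagePromptMessageContent` API):

```python
from enum import Enum


class ImageDetail(Enum):
    LOW = "low"
    HIGH = "high"


def resolve_image_detail(file_upload_config):
    """Prefer the configured detail; fall back to LOW when unset."""
    image_config = getattr(file_upload_config, "image_config", None)
    detail = getattr(image_config, "detail", None)
    return detail if detail is not None else ImageDetail.LOW
```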


@@ -1,122 +0,0 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import core.app.apps.completion.app_config_manager as module
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
class TestCompletionAppConfigManager:
def test_get_app_config_with_override(self, mocker):
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
app_model_config = MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
override_config = {"model": {"provider": "override"}}
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"]))
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
result = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_config,
)
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
assert result.app_model_config_dict == override_config
assert result.variables == ["v1"]
assert result.external_data_variables == ["ext1"]
assert result.app_mode == AppMode.COMPLETION
def test_get_app_config_without_override_uses_model_config(self, mocker):
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
app_model_config = MagicMock(id="cfg1")
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], []))
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
assert result.app_model_config_dict == {"model": {"provider": "x"}}
def test_config_validate_filters_related_keys(self, mocker):
config = {
"model": {"provider": "x"},
"variables": ["v"],
"file_upload": {"enabled": True},
"prompt": {"template": "t"},
"dataset": {"enabled": True},
"tts": {"enabled": True},
"more_like_this": {"enabled": True},
"moderation": {"enabled": True},
"extra": "drop",
}
mocker.patch.object(
module.ModelConfigManager,
"validate_and_set_defaults",
return_value=(config, ["model"]),
)
mocker.patch.object(
module.BasicVariablesConfigManager,
"validate_and_set_defaults",
return_value=(config, ["variables"]),
)
mocker.patch.object(
module.FileUploadConfigManager,
"validate_and_set_defaults",
return_value=(config, ["file_upload"]),
)
mocker.patch.object(
module.PromptTemplateConfigManager,
"validate_and_set_defaults",
return_value=(config, ["prompt"]),
)
mocker.patch.object(
module.DatasetConfigManager,
"validate_and_set_defaults",
return_value=(config, ["dataset"]),
)
mocker.patch.object(
module.TextToSpeechConfigManager,
"validate_and_set_defaults",
return_value=(config, ["tts"]),
)
mocker.patch.object(
module.MoreLikeThisConfigManager,
"validate_and_set_defaults",
return_value=(config, ["more_like_this"]),
)
mocker.patch.object(
module.SensitiveWordAvoidanceConfigManager,
"validate_and_set_defaults",
return_value=(config, ["moderation"]),
)
filtered = CompletionAppConfigManager.config_validate("tenant", config)
assert "extra" not in filtered
assert set(filtered.keys()) == {
"model",
"variables",
"file_upload",
"prompt",
"dataset",
"tts",
"more_like_this",
"moderation",
}


@@ -1,321 +0,0 @@
import contextlib
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
import core.app.apps.completion.app_generator as module
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
@pytest.fixture
def generator(mocker):
gen = CompletionAppGenerator()
mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn)
flask_app = MagicMock()
flask_app.app_context.return_value = contextlib.nullcontext()
mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app)))
thread = MagicMock()
mocker.patch.object(module.threading, "Thread", return_value=thread)
mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock())
mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock())
mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
return gen
def _build_app_model():
return MagicMock(tenant_id="tenant", id="app1", mode="completion")
def _build_user():
return MagicMock(id="user", session_id="session")
def _build_app_model_config():
config = MagicMock(id="cfg")
config.to_dict.return_value = {"model": {"provider": "x"}}
return config
class TestCompletionAppGenerator:
def test_generate_invalid_query_type(self, generator):
with pytest.raises(ValueError):
generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": 123, "inputs": {}, "files": []},
invoke_from=InvokeFrom.WEB_APP,
streaming=True,
)
def test_generate_override_not_debugger(self, generator):
with pytest.raises(ValueError):
generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": "q", "inputs": {}, "files": [], "model_config": {}},
invoke_from=InvokeFrom.WEB_APP,
streaming=False,
)
def test_generate_success_no_file_config(self, generator, mocker):
app_model_config = _build_app_model_config()
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
mocker.patch.object(module.file_factory, "build_from_mappings")
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
conversation = MagicMock(id="conv", mode="completion")
message = MagicMock(id="msg")
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
result = generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": "q", "inputs": {"a": 1}, "files": []},
invoke_from=InvokeFrom.WEB_APP,
streaming=True,
)
assert result == "converted"
module.file_factory.build_from_mappings.assert_not_called()
def test_generate_success_with_files(self, generator, mocker):
app_model_config = _build_app_model_config()
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
file_extra_config = MagicMock()
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
conversation = MagicMock(id="conv", mode="completion")
message = MagicMock(id="msg")
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
result = generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]},
invoke_from=InvokeFrom.WEB_APP,
streaming=False,
)
assert result == "converted"
module.file_factory.build_from_mappings.assert_called_once()
def test_generate_override_model_config_debugger(self, generator, mocker):
app_model_config = _build_app_model_config()
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
override_config = {"model": {"provider": "override"}}
mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config)
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
get_app_config = mocker.patch.object(
module.CompletionAppConfigManager,
"get_app_config",
return_value=app_config,
)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
mocker.patch.object(
generator,
"_init_generate_records",
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
)
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": "q", "inputs": {}, "files": [], "model_config": override_config},
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
assert get_app_config.call_args.kwargs["override_config_dict"] == override_config
def test_generate_more_like_this_message_not_found(self, generator, mocker):
session = mocker.MagicMock()
session.scalar.return_value = None
mocker.patch.object(module.db, "session", session)
with pytest.raises(MessageNotExistsError):
generator.generate_more_like_this(
app_model=_build_app_model(),
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_disabled(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False})
message = MagicMock()
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
with pytest.raises(MoreLikeThisDisabledError):
generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_app_model_config_missing(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = None
message = MagicMock()
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
with pytest.raises(MoreLikeThisDisabledError):
generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_message_config_none(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
message = MagicMock(app_model_config=None)
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
with pytest.raises(ValueError):
generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_success(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
message = MagicMock()
message.message_files = [{"id": "f"}]
message.inputs = {"a": 1}
message.query = "q"
app_model_config = MagicMock()
app_model_config.to_dict.return_value = {
"model": {"completion_params": {"temperature": 0.1}},
"file_upload": {"enabled": True},
}
message.app_model_config = app_model_config
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
file_extra_config = MagicMock()
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
get_app_config = mocker.patch.object(
module.CompletionAppConfigManager,
"get_app_config",
return_value=app_config,
)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(
generator,
"_init_generate_records",
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
)
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
result = generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
stream=True,
)
assert result == "converted"
override_dict = get_app_config.call_args.kwargs["override_config_dict"]
assert override_dict["model"]["completion_params"]["temperature"] == 0.9
@pytest.mark.parametrize(
("error", "should_publish"),
[
(GenerateTaskStoppedError(), False),
(InvokeAuthorizationError("bad"), True),
(
ValidationError.from_exception_data(
"Model",
[{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}],
),
True,
),
(ValueError("bad"), True),
(RuntimeError("boom"), True),
],
)
def test_generate_worker_error_handling(self, generator, mocker, error, should_publish):
flask_app = MagicMock()
flask_app.app_context.return_value = contextlib.nullcontext()
session = mocker.MagicMock()
mocker.patch.object(module.db, "session", session)
mocker.patch.object(generator, "_get_message", return_value=MagicMock())
runner_instance = MagicMock()
runner_instance.run.side_effect = error
mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance)
queue_manager = MagicMock()
generator._generate_worker(
flask_app=flask_app,
application_generate_entity=MagicMock(),
queue_manager=queue_manager,
message_id="msg",
)
assert queue_manager.publish_error.called is should_publish


@@ -1,153 +0,0 @@
from collections.abc import Generator
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageStreamResponse,
PingStreamResponse,
)
class TestCompletionAppGenerateResponseConverter:
def test_convert_blocking_full_response(self):
blocking = CompletionAppBlockingResponse(
task_id="task",
data=CompletionAppBlockingResponse.Data(
id="id",
mode="completion",
message_id="msg",
answer="answer",
metadata={"k": "v"},
created_at=123,
),
)
result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking)
assert result["event"] == "message"
assert result["task_id"] == "task"
assert result["message_id"] == "msg"
assert result["answer"] == "answer"
assert result["metadata"] == {"k": "v"}
def test_convert_blocking_simple_response_metadata_simplified(self):
metadata = {
"retriever_resources": [
{
"segment_id": "s",
"position": 1,
"document_name": "doc",
"score": 0.9,
"content": "c",
"summary": "sum",
"extra": "x",
}
],
"annotation_reply": {"a": 1},
"usage": {"t": 2},
}
blocking = CompletionAppBlockingResponse(
task_id="task",
data=CompletionAppBlockingResponse.Data(
id="id",
mode="completion",
message_id="msg",
answer="answer",
metadata=metadata,
created_at=123,
),
)
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "annotation_reply" not in result["metadata"]
assert "usage" not in result["metadata"]
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
assert "extra" not in result["metadata"]["retriever_resources"][0]
def test_convert_blocking_simple_response_metadata_not_dict(self):
data = CompletionAppBlockingResponse.Data.model_construct(
id="id",
mode="completion",
message_id="msg",
answer="answer",
metadata="bad",
created_at=123,
)
blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data)
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert result["metadata"] == {}
def test_convert_stream_full_response(self):
def stream() -> Generator[AppStreamResponse, None, None]:
yield CompletionAppStreamResponse(
stream_response=PingStreamResponse(task_id="t"),
message_id="m",
created_at=1,
)
yield CompletionAppStreamResponse(
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
message_id="m",
created_at=2,
)
yield CompletionAppStreamResponse(
stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"),
message_id="m",
created_at=3,
)
result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream()))
assert result[0] == "ping"
assert result[1]["event"] == "error"
assert result[1]["code"] == "invalid_param"
assert result[2]["event"] == "message"
def test_convert_stream_simple_response(self):
def stream() -> Generator[AppStreamResponse, None, None]:
yield CompletionAppStreamResponse(
stream_response=PingStreamResponse(task_id="t"),
message_id="m",
created_at=1,
)
yield CompletionAppStreamResponse(
stream_response=MessageEndStreamResponse(
task_id="t",
id="end",
metadata={
"retriever_resources": [
{
"segment_id": "s",
"position": 1,
"document_name": "doc",
"score": 0.9,
"content": "c",
"summary": "sum",
}
],
"annotation_reply": {"a": 1},
"usage": {"t": 2},
},
),
message_id="m",
created_at=2,
)
yield CompletionAppStreamResponse(
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
message_id="m",
created_at=3,
)
result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream()))
assert result[0] == "ping"
assert result[1]["event"] == "message_end"
assert "annotation_reply" not in result[1]["metadata"]
assert "usage" not in result[1]["metadata"]
assert result[2]["event"] == "error"


@@ -1,55 +0,0 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import core.app.apps.pipeline.pipeline_config_manager as module
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from models.model import AppMode
def test_get_pipeline_config(mocker):
pipeline = MagicMock(tenant_id="tenant", id="pipe1")
workflow = MagicMock(id="wf1")
mocker.patch.object(
module.WorkflowVariablesConfigManager,
"convert_rag_pipeline_variable",
return_value=["var1"],
)
mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start")
assert result.tenant_id == "tenant"
assert result.app_id == "pipe1"
assert result.workflow_id == "wf1"
assert result.app_mode == AppMode.RAG_PIPELINE
assert result.rag_pipeline_variables == ["var1"]
def test_config_validate_filters_related_keys(mocker):
config = {
"file_upload": {"enabled": True},
"tts": {"enabled": True},
"moderation": {"enabled": True},
"extra": "drop",
}
mocker.patch.object(
module.FileUploadConfigManager,
"validate_and_set_defaults",
return_value=(config, ["file_upload"]),
)
mocker.patch.object(
module.TextToSpeechConfigManager,
"validate_and_set_defaults",
return_value=(config, ["tts"]),
)
mocker.patch.object(
module.SensitiveWordAvoidanceConfigManager,
"validate_and_set_defaults",
return_value=(config, ["moderation"]),
)
filtered = PipelineConfigManager.config_validate("tenant", config)
assert set(filtered.keys()) == {"file_upload", "tts", "moderation"}


@@ -1,111 +0,0 @@
from collections.abc import Generator
from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
def test_convert_blocking_full_and_simple_response():
blocking = WorkflowAppBlockingResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowAppBlockingResponse.Data(
id="id",
workflow_id="wf",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs={"k": "v"},
error=None,
elapsed_time=0.1,
total_tokens=10,
total_steps=1,
created_at=1,
finished_at=2,
),
)
full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert full == simple
assert full["workflow_run_id"] == "run"
assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED
def test_convert_stream_full_response():
def stream() -> Generator[AppStreamResponse, None, None]:
yield WorkflowAppStreamResponse(
stream_response=PingStreamResponse(task_id="t"),
workflow_run_id="run",
)
yield WorkflowAppStreamResponse(
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
workflow_run_id="run",
)
result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream()))
assert result[0] == "ping"
assert result[1]["event"] == "error"
assert result[1]["code"] == "invalid_param"
def test_convert_stream_simple_response_node_ignore_details():
node_start = NodeStartStreamResponse(
task_id="t",
workflow_run_id="run",
data=NodeStartStreamResponse.Data(
id="nid",
node_id="node",
node_type="type",
title="Title",
index=1,
predecessor_node_id=None,
inputs={"a": 1},
inputs_truncated=False,
created_at=1,
),
)
node_finish = NodeFinishStreamResponse(
task_id="t",
workflow_run_id="run",
data=NodeFinishStreamResponse.Data(
id="nid",
node_id="node",
node_type="type",
title="Title",
index=1,
predecessor_node_id=None,
inputs={"a": 1},
inputs_truncated=False,
process_data=None,
process_data_truncated=False,
outputs={"b": 2},
outputs_truncated=False,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=0.1,
execution_metadata=None,
created_at=1,
finished_at=2,
files=[],
),
)
def stream() -> Generator[AppStreamResponse, None, None]:
yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run")
yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run")
result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
assert result[0]["event"] == "node_started"
assert result[0]["data"]["inputs"] is None
assert result[1]["event"] == "node_finished"
assert result[1]["data"]["inputs"] is None


@@ -1,699 +0,0 @@
import contextlib
from types import SimpleNamespace
from unittest.mock import MagicMock, PropertyMock
import pytest
import core.app.apps.pipeline.pipeline_generator as module
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceProviderType
class FakeRagPipelineGenerateEntity(SimpleNamespace):
class SingleIterationRunEntity(SimpleNamespace):
pass
class SingleLoopRunEntity(SimpleNamespace):
pass
def model_dump(self):
return dict(self.__dict__)
@pytest.fixture
def generator(mocker):
gen = module.PipelineGenerator()
mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity)
mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs)
mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock()))
mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock()))
return gen
def _build_pipeline_dataset():
return SimpleNamespace(
id="ds",
name="dataset",
description="desc",
chunk_structure="chunk",
built_in_field_enabled=True,
tenant_id="tenant",
)
def _build_pipeline():
pipeline = MagicMock(tenant_id="tenant", id="pipe")
pipeline.retrieve_dataset.return_value = _build_pipeline_dataset()
return pipeline
def _build_workflow():
return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant")
def _build_user():
return MagicMock(id="user", name="User", session_id="session")
def _build_args():
return {
"inputs": {"k": "v"},
"start_node_id": "start",
"datasource_type": DatasourceProviderType.LOCAL_FILE.value,
"datasource_info_list": [{"name": "file"}],
}
def _patch_session(mocker, session):
mocker.patch.object(module, "Session", return_value=session)
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
def _dummy_preserve(*args, **kwargs):
return contextlib.nullcontext()
class DummySession:
def __init__(self):
self.scalar = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def test_generate_dataset_missing(generator, mocker):
pipeline = _build_pipeline()
pipeline.retrieve_dataset.return_value = None
session = DummySession()
_patch_session(mocker, session)
with pytest.raises(ValueError):
generator.generate(
pipeline=pipeline,
workflow=_build_workflow(),
user=_build_user(),
args=_build_args(),
invoke_from=InvokeFrom.WEB_APP,
streaming=False,
)
def test_generate_debugger_calls_generate(generator, mocker):
pipeline = _build_pipeline()
workflow = _build_workflow()
session = DummySession()
_patch_session(mocker, session)
mocker.patch.object(
generator,
"_format_datasource_info_list",
return_value=[{"name": "file"}],
)
mocker.patch.object(
module.PipelineConfigManager,
"get_pipeline_config",
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
)
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
result = generator.generate(
pipeline=pipeline,
workflow=workflow,
user=_build_user(),
args=_build_args(),
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
assert result == {"result": "ok"}
def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker):
pipeline = _build_pipeline()
workflow = _build_workflow()
session = DummySession()
_patch_session(mocker, session)
datasource_info_list = [{"name": "file1"}, {"name": "file2"}]
mocker.patch.object(
generator,
"_format_datasource_info_list",
return_value=datasource_info_list,
)
mocker.patch.object(
module.PipelineConfigManager,
"get_pipeline_config",
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
)
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1)
document1 = SimpleNamespace(
id="doc1",
position=1,
data_source_type=DatasourceProviderType.LOCAL_FILE,
data_source_info="{}",
name="file1",
indexing_status="",
error=None,
enabled=True,
)
document2 = SimpleNamespace(
id="doc2",
position=2,
data_source_type=DatasourceProviderType.LOCAL_FILE,
data_source_info="{}",
name="file2",
indexing_status="",
error=None,
enabled=True,
)
mocker.patch.object(generator, "_build_document", side_effect=[document1, document2])
mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock())
db_session = MagicMock()
mocker.patch.object(module.db, "session", db_session)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
return_value=MagicMock(),
)
task_proxy = MagicMock()
mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy)
result = generator.generate(
pipeline=pipeline,
workflow=workflow,
user=_build_user(),
args=_build_args(),
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
streaming=False,
)
assert result["batch"]
assert len(result["documents"]) == 2
task_proxy.delay.assert_called_once()
def test_generate_is_retry_calls_generate(generator, mocker):
pipeline = _build_pipeline()
workflow = _build_workflow()
session = DummySession()
_patch_session(mocker, session)
mocker.patch.object(
generator,
"_format_datasource_info_list",
return_value=[{"name": "file"}],
)
mocker.patch.object(
module.PipelineConfigManager,
"get_pipeline_config",
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
)
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
result = generator.generate(
pipeline=pipeline,
workflow=workflow,
user=_build_user(),
args=_build_args(),
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
streaming=True,
is_retry=True,
)
assert result == {"result": "ok"}
def test_generate_worker_handles_errors(generator, mocker):
flask_app = MagicMock()
flask_app.app_context.return_value = contextlib.nullcontext()
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
application_generate_entity = FakeRagPipelineGenerateEntity(
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
invoke_from=InvokeFrom.WEB_APP,
user_id="user",
)
session = DummySession()
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
_patch_session(mocker, session)
runner_instance = MagicMock()
runner_instance.run.side_effect = ValueError("bad")
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
queue_manager = MagicMock()
generator._generate_worker(
flask_app=flask_app,
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
context=contextlib.nullcontext(),
variable_loader=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
queue_manager.publish_error.assert_called_once()
def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker):
flask_app = MagicMock()
flask_app.app_context.return_value = contextlib.nullcontext()
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
application_generate_entity = FakeRagPipelineGenerateEntity(
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
invoke_from=InvokeFrom.WEB_APP,
user_id="user",
)
session = DummySession()
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
_patch_session(mocker, session)
runner_instance = MagicMock()
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
generator._generate_worker(
flask_app=flask_app,
application_generate_entity=application_generate_entity,
queue_manager=MagicMock(),
context=contextlib.nullcontext(),
variable_loader=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session"
def test_generate_raises_when_workflow_not_found(generator, mocker):
flask_app = MagicMock()
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
mocker.patch.object(module.db, "session", session)
with pytest.raises(ValueError):
generator._generate(
flask_app=flask_app,
context=contextlib.nullcontext(),
pipeline=_build_pipeline(),
workflow_id="wf",
user=_build_user(),
application_generate_entity=FakeRagPipelineGenerateEntity(
task_id="t",
app_config=SimpleNamespace(app_id="pipe"),
user_id="user",
invoke_from=InvokeFrom.DEBUGGER,
),
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
streaming=True,
)
def test_generate_success_returns_converted(generator, mocker):
flask_app = MagicMock()
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = workflow
mocker.patch.object(module.db, "session", session)
queue_manager = MagicMock()
mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager)
worker_thread = MagicMock()
mocker.patch.object(module.threading, "Thread", return_value=worker_thread)
mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock())
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted")
result = generator._generate(
flask_app=flask_app,
context=contextlib.nullcontext(),
pipeline=_build_pipeline(),
workflow_id="wf",
user=_build_user(),
application_generate_entity=FakeRagPipelineGenerateEntity(
task_id="t",
app_config=SimpleNamespace(app_id="pipe"),
user_id="user",
invoke_from=InvokeFrom.DEBUGGER,
),
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
streaming=True,
)
assert result == "converted"
def test_single_iteration_generate_validates_inputs(generator, mocker):
with pytest.raises(ValueError):
generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {})
with pytest.raises(ValueError):
generator.single_iteration_generate(
_build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None}
)
def test_single_iteration_generate_dataset_required(generator, mocker):
pipeline = _build_pipeline()
pipeline.retrieve_dataset.return_value = None
session = DummySession()
_patch_session(mocker, session)
with pytest.raises(ValueError):
generator.single_iteration_generate(
pipeline,
_build_workflow(),
"node",
_build_user(),
{"inputs": {"a": 1}},
)
def test_single_iteration_generate_success(generator, mocker):
pipeline = _build_pipeline()
session = DummySession()
_patch_session(mocker, session)
mocker.patch.object(
module.PipelineConfigManager,
"get_pipeline_config",
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
mocker.patch.object(generator, "_generate", return_value={"ok": True})
result = generator.single_iteration_generate(
pipeline,
_build_workflow(),
"node",
_build_user(),
{"inputs": {"a": 1}},
streaming=False,
)
assert result == {"ok": True}
def test_single_loop_generate_success(generator, mocker):
pipeline = _build_pipeline()
session = DummySession()
_patch_session(mocker, session)
mocker.patch.object(
module.PipelineConfigManager,
"get_pipeline_config",
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(
module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
return_value=MagicMock(),
)
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
mocker.patch.object(generator, "_generate", return_value={"ok": True})
result = generator.single_loop_generate(
pipeline,
_build_workflow(),
"node",
_build_user(),
{"inputs": {"a": 1}},
streaming=False,
)
assert result == {"ok": True}
def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker):
pipeline = _build_pipeline()
workflow = _build_workflow()
app_entity = FakeRagPipelineGenerateEntity(task_id="t")
task_pipeline = MagicMock()
task_pipeline.process.side_effect = ValueError("I/O operation on closed file.")
mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline)
with pytest.raises(GenerateTaskStoppedError):
generator._handle_response(
application_generate_entity=app_entity,
workflow=workflow,
queue_manager=MagicMock(),
user=_build_user(),
draft_var_saver_factory=MagicMock(),
stream=False,
)
def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
class DummyDocument(SimpleNamespace):
pass
mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs))
document = generator._build_document(
tenant_id="tenant",
dataset_id="ds",
built_in_field_enabled=True,
datasource_type=DatasourceProviderType.LOCAL_FILE,
datasource_info={"name": "file"},
created_from="rag-pipeline",
position=1,
account=_build_user(),
batch="batch",
document_form="text",
)
assert document.name == "file"
assert document.doc_metadata
def test_build_document_invalid_datasource_type(generator):
with pytest.raises(ValueError):
generator._build_document(
tenant_id="tenant",
dataset_id="ds",
built_in_field_enabled=False,
datasource_type="invalid",
datasource_info={},
created_from="rag-pipeline",
position=1,
account=_build_user(),
batch="batch",
document_form="text",
)
def test_format_datasource_info_list_non_online_drive(generator):
result = generator._format_datasource_info_list(
DatasourceProviderType.LOCAL_FILE,
[{"name": "file"}],
_build_pipeline(),
_build_workflow(),
"start",
_build_user(),
)
assert result == [{"name": "file"}]
def test_format_datasource_info_list_missing_node_data(generator):
workflow = MagicMock(graph_dict={"nodes": []})
with pytest.raises(ValueError):
generator._format_datasource_info_list(
DatasourceProviderType.ONLINE_DRIVE,
[],
_build_pipeline(),
workflow,
"start",
_build_user(),
)
def test_format_datasource_info_list_online_drive_folder(generator, mocker):
workflow = MagicMock(
graph_dict={
"nodes": [
{
"id": "start",
"data": {
"plugin_id": "p",
"provider_name": "provider",
"datasource_name": "drive",
"credential_id": "cred",
},
}
]
}
)
runtime = MagicMock()
runtime.runtime = SimpleNamespace(credentials=None)
runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE
mocker.patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=runtime,
)
mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"})
mocker.patch.object(
generator,
"_get_files_in_folder",
side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}),
)
result = generator._format_datasource_info_list(
DatasourceProviderType.ONLINE_DRIVE,
[{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}],
_build_pipeline(),
workflow,
"start",
_build_user(),
)
assert result == [{"id": "f"}]
def test_get_files_in_folder_recurses_and_collects(generator):
class File:
def __init__(self, id, name, type):
self.id = id
self.name = name
self.type = type
class FilesPage:
def __init__(self, files, is_truncated=False, next_page_parameters=None):
self.files = files
self.is_truncated = is_truncated
self.next_page_parameters = next_page_parameters
class Result:
def __init__(self, result):
self.result = result
class Runtime:
def __init__(self):
self.calls = []
def datasource_provider_type(self):
return DatasourceProviderType.ONLINE_DRIVE
def online_drive_browse_files(self, user_id, request, provider_type):
self.calls.append(request.next_page_parameters)
if request.prefix == "fd":
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
if request.next_page_parameters is None:
return iter(
[
Result(
[FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})]
)
]
)
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
runtime = Runtime()
all_files = []
generator._get_files_in_folder(
datasource_runtime=runtime,
prefix="root",
bucket="b",
user_id="user",
all_files=all_files,
datasource_info={},
)
assert {f["id"] for f in all_files} == {"f1", "f2"}


@@ -1,57 +0,0 @@
import pytest
import core.app.apps.pipeline.pipeline_queue_manager as module
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
)
from dify_graph.model_runtime.entities.llm_entities import LLMResult
def test_publish_sets_stop_listen_and_raises_on_stopped(mocker):
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
manager._q = mocker.MagicMock()
manager.stop_listen = mocker.MagicMock()
manager._is_stopped = mocker.MagicMock(return_value=True)
with pytest.raises(GenerateTaskStoppedError):
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER)
manager.stop_listen.assert_called_once()
def test_publish_stop_events_trigger_stop_listen(mocker):
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
manager._q = mocker.MagicMock()
manager.stop_listen = mocker.MagicMock()
manager._is_stopped = mocker.MagicMock(return_value=False)
for event in [
QueueErrorEvent(error=ValueError("bad")),
QueueMessageEndEvent(llm_result=LLMResult.model_construct()),
QueueWorkflowSucceededEvent(),
QueueWorkflowFailedEvent(error="failed", exceptions_count=1),
QueueWorkflowPartialSuccessEvent(exceptions_count=1),
]:
manager.stop_listen.reset_mock()
manager._publish(event, PublishFrom.TASK_PIPELINE)
manager.stop_listen.assert_called_once()
def test_publish_non_stop_event_no_stop_listen(mocker):
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
manager._q = mocker.MagicMock()
manager.stop_listen = mocker.MagicMock()
manager._is_stopped = mocker.MagicMock(return_value=False)
non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent)
manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE)
manager.stop_listen.assert_not_called()


@@ -1,297 +0,0 @@
"""Unit tests for PipelineRunner behavior.

This module validates core control-flow outcomes for
``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph
initialization guards, invoke-source to user-source resolution, and
failed-run event handling. Invariants asserted here include strict
graph-config validation, a correct ``InvokeFrom`` to ``UserFrom`` mapping,
and error-publishing paths driven by ``GraphRunFailedEvent`` through mocked
collaborators.

Primary collaborators: ``PipelineRunner``,
``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``,
``UserFrom``, and patched DB/runtime dependencies used by the runner.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.app.apps.pipeline.pipeline_runner as module
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from dify_graph.graph_events import GraphRunFailedEvent
def _build_app_generate_entity() -> SimpleNamespace:
app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant")
return SimpleNamespace(
app_config=app_config,
invoke_from=InvokeFrom.WEB_APP,
user_id="user",
trace_manager=MagicMock(),
inputs={"input1": "v1"},
files=[],
workflow_execution_id="run",
document_id="doc",
original_document_id=None,
batch="batch",
dataset_id="ds",
datasource_type="local_file",
datasource_info={"name": "file"},
start_node_id="start",
call_depth=0,
single_iteration_run=None,
single_loop_run=None,
)
@pytest.fixture
def runner():
app_generate_entity = _build_app_generate_entity()
queue_manager = MagicMock()
variable_loader = MagicMock()
workflow = MagicMock()
workflow_execution_repository = MagicMock()
workflow_node_execution_repository = MagicMock()
return PipelineRunner(
application_generate_entity=app_generate_entity,
queue_manager=queue_manager,
variable_loader=variable_loader,
workflow=workflow,
system_user_id="sys",
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
def test_get_app_id(runner):
assert runner._get_app_id() == "pipe"
def test_get_workflow_returns_workflow(mocker, runner):
pipeline = MagicMock(tenant_id="tenant", id="pipe")
workflow = MagicMock(id="wf")
query = MagicMock()
query.where.return_value.first.return_value = workflow
mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query)))
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
assert result == workflow
def test_init_rag_pipeline_graph_invalid_config(mocker, runner):
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={})
with pytest.raises(ValueError):
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
workflow.graph_dict = {"nodes": "bad", "edges": []}
with pytest.raises(ValueError):
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
workflow.graph_dict = {"nodes": [], "edges": "bad"}
with pytest.raises(ValueError):
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
def test_init_rag_pipeline_graph_not_found(mocker, runner):
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []})
mocker.patch.object(module.Graph, "init", return_value=None)
with pytest.raises(ValueError):
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
def test_update_document_status_on_failure(mocker, runner):
document = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = document
session = MagicMock()
session.query.return_value = query
mocker.patch.object(module.db, "session", session)
event = GraphRunFailedEvent(error="boom")
runner._update_document_status(event, document_id="doc", dataset_id="ds")
assert document.indexing_status == "error"
assert document.error == "boom"
session.commit.assert_called_once()
def test_run_pipeline_not_found(mocker):
app_generate_entity = _build_app_generate_entity()
app_generate_entity.invoke_from = InvokeFrom.WEB_APP
app_generate_entity.single_iteration_run = None
app_generate_entity.single_loop_run = None
query = MagicMock()
query.where.return_value.first.return_value = None
session = MagicMock()
session.query.return_value = query
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
queue_manager=MagicMock(),
variable_loader=MagicMock(),
workflow=MagicMock(),
system_user_id="sys",
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
with pytest.raises(ValueError):
runner.run()
def test_run_workflow_not_initialized(mocker):
app_generate_entity = _build_app_generate_entity()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
session = MagicMock()
session.query.return_value = query_pipeline
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
queue_manager=MagicMock(),
variable_loader=MagicMock(),
workflow=MagicMock(),
system_user_id="sys",
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
runner.get_workflow = MagicMock(return_value=None)
with pytest.raises(ValueError):
runner.run()
def test_run_single_iteration_path(mocker):
app_generate_entity = _build_app_generate_entity()
app_generate_entity.single_iteration_run = MagicMock()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
query_end_user = MagicMock()
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
session = MagicMock()
session.query.side_effect = [query_end_user, query_pipeline]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
queue_manager=MagicMock(),
variable_loader=MagicMock(),
workflow=MagicMock(),
system_user_id="sys",
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
runner.get_workflow = MagicMock(
return_value=MagicMock(
id="wf",
tenant_id="tenant",
app_id="pipe",
graph_dict={},
type="rag-pipeline",
version="v1",
)
)
runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state"))
runner._update_document_status = MagicMock()
runner._handle_event = MagicMock()
workflow_entry = MagicMock()
workflow_entry.graph_engine = MagicMock()
workflow_entry.run.return_value = [MagicMock()]
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
runner.run()
runner._prepare_single_node_execution.assert_called_once()
runner._handle_event.assert_called()
def test_run_normal_path_builds_graph(mocker):
app_generate_entity = _build_app_generate_entity()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
query_end_user = MagicMock()
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
session = MagicMock()
session.query.side_effect = [query_end_user, query_pipeline]
mocker.patch.object(module.db, "session", session)
workflow = MagicMock(
id="wf",
tenant_id="tenant",
app_id="pipe",
graph_dict={"nodes": [], "edges": []},
environment_variables=[],
rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}],
type="rag-pipeline",
version="v1",
)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
queue_manager=MagicMock(),
variable_loader=MagicMock(),
workflow=workflow,
system_user_id="sys",
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
runner.get_workflow = MagicMock(return_value=workflow)
runner._init_rag_pipeline_graph = MagicMock(return_value="graph")
runner._update_document_status = MagicMock()
runner._handle_event = MagicMock()
mocker.patch.object(
module.RAGPipelineVariable,
"model_validate",
return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"),
)
mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
workflow_entry = MagicMock()
workflow_entry.graph_engine = MagicMock()
workflow_entry.run.return_value = []
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
runner.run()
runner._init_rag_pipeline_graph.assert_called_once()
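Several tests above stub ``_resolve_user_from`` and assert the result feeds the runner. The mapping they assume can be sketched as follows; the console-vs-published split is an assumption drawn from the test setup, not stated in this diff, and the enum values are illustrative.

```python
from enum import Enum

class InvokeFrom(Enum):
    SERVICE_API = "service-api"
    WEB_APP = "web-app"
    EXPLORE = "explore"
    DEBUGGER = "debugger"

class UserFrom(Enum):
    ACCOUNT = "account"
    END_USER = "end-user"

def resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
    # Assumed split: console-side invocations (debugger, explore) run as a
    # workspace account; published-app invocations (web app, service API)
    # run as an end user.
    if invoke_from in (InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE):
        return UserFrom.ACCOUNT
    return UserFrom.END_USER
```

Stubbing this mapping keeps the runner tests focused on control flow rather than identity resolution.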


@@ -1,5 +1,3 @@
from unittest.mock import MagicMock
import pytest
from core.app.apps.base_app_generator import BaseAppGenerator
@@ -368,132 +366,3 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default():
)
assert result is None
class TestBaseAppGeneratorExtras:
def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch):
base_app_generator = BaseAppGenerator()
variables = [
VariableEntity(
variable="file",
label="file",
type=VariableEntityType.FILE,
required=False,
allowed_file_types=[],
allowed_file_extensions=[],
allowed_file_upload_methods=[],
),
VariableEntity(
variable="file_list",
label="file_list",
type=VariableEntityType.FILE_LIST,
required=False,
allowed_file_types=[],
allowed_file_extensions=[],
allowed_file_upload_methods=[],
),
VariableEntity(
variable="json",
label="json",
type=VariableEntityType.JSON_OBJECT,
required=False,
),
]
monkeypatch.setattr(
"core.app.apps.base_app_generator.file_factory.build_from_mapping",
lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
)
monkeypatch.setattr(
"core.app.apps.base_app_generator.file_factory.build_from_mappings",
lambda mappings, tenant_id, config: ["file-1", "file-2"],
)
user_inputs = {
"file": {"id": "file-id"},
"file_list": [{"id": "file-1"}, {"id": "file-2"}],
"json": {"key": "value"},
}
prepared = base_app_generator._prepare_user_inputs(
user_inputs=user_inputs,
variables=variables,
tenant_id="tenant-id",
)
assert prepared["file"] == "file-object"
assert prepared["file_list"] == ["file-1", "file-2"]
assert prepared["json"] == {"key": "value"}
def test_prepare_user_inputs_rejects_invalid_dict_inputs(self):
base_app_generator = BaseAppGenerator()
variables = [
VariableEntity(
variable="text",
label="text",
type=VariableEntityType.TEXT_INPUT,
required=False,
)
]
with pytest.raises(ValueError, match="must be a string"):
base_app_generator._prepare_user_inputs(
user_inputs={"text": {"unexpected": "dict"}},
variables=variables,
tenant_id="tenant-id",
)
def test_prepare_user_inputs_rejects_invalid_list_inputs(self):
base_app_generator = BaseAppGenerator()
variables = [
VariableEntity(
variable="text",
label="text",
type=VariableEntityType.TEXT_INPUT,
required=False,
)
]
with pytest.raises(ValueError, match="must be a string"):
base_app_generator._prepare_user_inputs(
user_inputs={"text": [{"unexpected": "dict"}]},
variables=variables,
tenant_id="tenant-id",
)
def test_convert_to_event_stream(self):
base_app_generator = BaseAppGenerator()
assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True}
def _gen():
yield {"delta": "hi"}
yield "ping"
converted = list(base_app_generator.convert_to_event_stream(_gen()))
assert converted[0].startswith("data: ")
assert "\n\n" in converted[0]
assert converted[1] == "event: ping\n\n"
def test_get_draft_var_saver_factory_debugger(self):
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.enums import NodeType
from models import Account
base_app_generator = BaseAppGenerator()
account = Account(name="Tester", email="tester@example.com")
account.id = "account-id"
account.tenant_id = "tenant-id"
factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account)
saver = factory(
session=MagicMock(),
app_id="app-id",
node_id="node-id",
node_type=NodeType.START,
node_execution_id="node-exec-id",
)
assert saver is not None
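``test_convert_to_event_stream`` above pins the server-sent-events framing: mappings pass through unchanged, dict chunks become ``data:`` frames, and bare strings become named events. A minimal implementation consistent with those assertions might look like this (a sketch, not the library's actual code):

```python
import json
from collections.abc import Generator

def convert_to_event_stream(response):
    """Pass mappings through unchanged; frame generator chunks as SSE lines."""
    if isinstance(response, Generator):
        def stream():
            for chunk in response:
                if isinstance(chunk, dict):
                    # JSON payloads become data frames.
                    yield f"data: {json.dumps(chunk)}\n\n"
                else:
                    # Bare strings (e.g. "ping") become named events.
                    yield f"event: {chunk}\n\n"
        return stream()
    return response
```

The blank line after each frame (``\n\n``) is what terminates an event in the SSE wire format, which is why the test checks for it.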


@@ -1,61 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueErrorEvent
class DummyQueueManager(AppQueueManager):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.published = []
def _publish(self, event, pub_from):
self.published.append((event, pub_from))
class TestBaseAppQueueManager:
def test_init_requires_user_id(self):
with pytest.raises(ValueError):
DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API)
def test_publish_error_records_event(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE)
assert isinstance(manager.published[0][0], QueueErrorEvent)
def test_set_stop_flag_checks_user(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.get.return_value = b"end-user-u1"
AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1")
mock_redis.setex.assert_called_once()
def test_set_stop_flag_no_user_check(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
AppQueueManager.set_stop_flag_no_user_check(task_id="t1")
mock_redis.setex.assert_called_once()
def test_is_stopped_reads_cache(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
mock_redis.get.return_value = b"1"
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
assert manager._is_stopped() is True
def test_check_for_sqlalchemy_models_raises(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
bad = SimpleNamespace(_sa_instance_state=True)
with pytest.raises(TypeError):
manager._check_for_sqlalchemy_models(bad)
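The queue-manager tests above revolve around a Redis-backed stop flag: setting it is gated on task ownership, and `_is_stopped` reads it back. The protocol can be sketched with an in-memory stand-in; the key format and ownership check are illustrative assumptions, not the real implementation.

```python
class StopFlagStore:
    """In-memory stand-in for the Redis-backed stop-flag protocol (sketch)."""

    def __init__(self):
        self._kv = {}

    def register_task(self, task_id, user_id):
        # Recorded at generation start so later stop requests can be authorized.
        self._kv[f"generate_task_belongs:{task_id}"] = user_id

    def set_stop_flag(self, task_id, user_id):
        # Only the user who started the task may stop it.
        owner = self._kv.get(f"generate_task_belongs:{task_id}")
        if owner is not None and owner != user_id:
            return
        self._kv[f"generate_task_stopped:{task_id}"] = "1"

    def is_stopped(self, task_id):
        return self._kv.get(f"generate_task_stopped:{task_id}") == "1"

store = StopFlagStore()
store.register_task("t1", "u1")
store.set_stop_flag("t1", "u2")  # wrong user: ignored
store.set_stop_flag("t1", "u1")  # owner: flag set
```

In the real tests the key-value store is `redis_client` with a TTL (`setex`), which is why the mocks assert `setex` rather than `set`.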


@@ -1,442 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
PromptTemplateEntity,
)
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from models.model import AppMode
class _DummyParameterRule:
def __init__(self, name: str, use_template: str | None = None) -> None:
self.name = name
self.use_template = use_template
class _QueueRecorder:
def __init__(self) -> None:
self.events: list[object] = []
def publish(self, event, pub_from):
_ = pub_from
self.events.append(event)
class TestAppRunner:
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch):
runner = AppRunner()
model_schema = SimpleNamespace(
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
parameter_rules=[_DummyParameterRule("max_tokens")],
)
model_config = SimpleNamespace(
provider_model_bundle=object(),
model="mock",
model_schema=model_schema,
parameters={"max_tokens": 30},
)
monkeypatch.setattr(
"core.app.apps.base_app_runner.ModelInstance",
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80),
)
runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")])
assert model_config.parameters["max_tokens"] == 20
def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch):
runner = AppRunner()
model_schema = SimpleNamespace(
model_properties={},
parameter_rules=[_DummyParameterRule("max_tokens")],
)
model_config = SimpleNamespace(
provider_model_bundle=object(),
model="mock",
model_schema=model_schema,
parameters={"max_tokens": 30},
)
monkeypatch.setattr(
"core.app.apps.base_app_runner.ModelInstance",
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10),
)
assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1
def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True)
monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None)
runner.direct_output(
queue_manager=queue,
app_generate_entity=app_generate_entity,
prompt_messages=[],
text="hi",
stream=True,
)
assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
def test_handle_invoke_result_direct_publishes_end_event(self):
runner = AppRunner()
queue = _QueueRecorder()
llm_result = LLMResult(
model="mock",
prompt_messages=[],
message=AssistantPromptMessage(content="done"),
usage=LLMUsage.empty_usage(),
)
runner._handle_invoke_result(
invoke_result=llm_result,
queue_manager=queue,
stream=False,
)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
def test_handle_invoke_result_invalid_type_raises(self):
runner = AppRunner()
queue = _QueueRecorder()
with pytest.raises(NotImplementedError):
runner._handle_invoke_result(
invoke_result=["unexpected"],
queue_manager=queue,
stream=True,
)
def test_organize_prompt_messages_simple_template(self, monkeypatch):
runner = AppRunner()
model_config = SimpleNamespace(mode="chat", stop=["STOP"])
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="hello",
)
monkeypatch.setattr(
"core.app.apps.base_app_runner.SimplePromptTransform.get_prompt",
lambda self, **kwargs: (["simple-message"], ["simple-stop"]),
)
prompt_messages, stop = runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs={},
files=[],
query="q",
)
assert prompt_messages == ["simple-message"]
assert stop == ["simple-stop"]
def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch):
runner = AppRunner()
model_config = SimpleNamespace(mode="completion", stop=["<END>"])
captured: dict[str, object] = {}
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="answer",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"),
),
)
def _fake_advanced_prompt(self, **kwargs):
captured.update(kwargs)
return ["advanced-completion-message"]
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
prompt_messages, stop = runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs={},
files=[],
query="q",
)
assert prompt_messages == ["advanced-completion-message"]
assert stop == ["<END>"]
memory_config = captured["memory_config"]
assert memory_config.role_prefix.user == "U"
assert memory_config.role_prefix.assistant == "A"
def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch):
runner = AppRunner()
model_config = SimpleNamespace(mode="chat", stop=["<END>"])
captured: dict[str, object] = {}
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT),
]
),
)
def _fake_advanced_prompt(self, **kwargs):
captured.update(kwargs)
return ["advanced-chat-message"]
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
prompt_messages, stop = runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs={},
files=[],
query="q",
)
assert prompt_messages == ["advanced-chat-message"]
assert stop == ["<END>"]
assert len(captured["prompt_template"]) == 2
def test_organize_prompt_messages_advanced_missing_templates_raise(self):
runner = AppRunner()
with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"):
runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=SimpleNamespace(mode="completion", stop=[]),
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
inputs={},
files=[],
)
with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"):
runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=SimpleNamespace(mode="chat", stop=[]),
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
inputs={},
files=[],
)
def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
warning_logger = MagicMock()
monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger)
image_content = ImagePromptMessageContent(
url="https://example.com/image.png", format="png", mime_type="image/png"
)
def _stream():
yield LLMResultChunk(
model="stream-model",
prompt_messages=[AssistantPromptMessage(content="prompt")],
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage.model_construct(
content=[
"a",
TextPromptMessageContent(data="b"),
SimpleNamespace(data="c"),
image_content,
]
),
),
)
runner._handle_invoke_result(
invoke_result=_stream(),
queue_manager=queue,
stream=True,
agent=False,
)
assert isinstance(queue.events[0], QueueLLMChunkEvent)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
assert queue.events[-1].llm_result.message.content == "abc"
warning_logger.assert_called_once()
def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
exception_logger = MagicMock()
monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger)
monkeypatch.setattr(
runner,
"_handle_multimodal_image_content",
MagicMock(side_effect=RuntimeError("failed to save image")),
)
usage = LLMUsage.empty_usage()
def _stream():
yield LLMResultChunk(
model="agent-model",
prompt_messages=[AssistantPromptMessage(content="prompt")],
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
ImagePromptMessageContent(
url="https://example.com/image.png",
format="png",
mime_type="image/png",
),
TextPromptMessageContent(data="done"),
]
),
usage=usage,
),
)
runner._handle_invoke_result_stream(
invoke_result=_stream(),
queue_manager=queue,
agent=True,
message_id="message-id",
user_id="user-id",
tenant_id="tenant-id",
)
assert isinstance(queue.events[0], QueueAgentMessageEvent)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
assert queue.events[-1].llm_result.usage == usage
exception_logger.assert_called_once()
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch):
runner = AppRunner()
class _ToggleBool:
def __init__(self, values: list[bool]):
self._values = values
self._index = 0
def __bool__(self):
value = self._values[min(self._index, len(self._values) - 1)]
self._index += 1
return value
content = SimpleNamespace(
url=_ToggleBool([False, False]),
base64_data=_ToggleBool([True, False]),
mime_type="image/png",
)
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock())
monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock())
monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session))
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock())
runner._handle_multimodal_image_content(
content=content,
message_id="message-id",
user_id="user-id",
tenant_id="tenant-id",
queue_manager=queue_manager,
)
db_session.add.assert_not_called()
queue_manager.publish.assert_not_called()
def test_check_hosting_moderation_direct_output_called(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
app_generate_entity = SimpleNamespace(stream=False)
monkeypatch.setattr(
"core.app.apps.base_app_runner.HostingModerationFeature.check",
lambda self, application_generate_entity, prompt_messages: True,
)
direct_output = MagicMock()
monkeypatch.setattr(runner, "direct_output", direct_output)
result = runner.check_hosting_moderation(
application_generate_entity=app_generate_entity,
queue_manager=queue,
prompt_messages=[],
)
assert result is True
assert direct_output.called
def test_fill_in_inputs_from_external_data_tools(self, monkeypatch):
runner = AppRunner()
monkeypatch.setattr(
"core.app.apps.base_app_runner.ExternalDataFetch.fetch",
lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"},
)
result = runner.fill_in_inputs_from_external_data_tools(
tenant_id="tenant",
app_id="app",
external_data_tools=[],
inputs={},
query="q",
)
assert result == {"foo": "bar"}
def test_moderation_for_inputs_returns_result(self, monkeypatch):
runner = AppRunner()
monkeypatch.setattr(
"core.app.apps.base_app_runner.InputModeration.check",
lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""),
)
app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None)
result = runner.moderation_for_inputs(
app_id="app",
tenant_id="tenant",
app_generate_entity=app_generate_entity,
inputs={},
query="q",
message_id="msg",
)
assert result == (True, {}, "")
def test_query_app_annotations_to_reply(self, monkeypatch):
runner = AppRunner()
monkeypatch.setattr(
"core.app.apps.base_app_runner.AnnotationReplyFeature.query",
lambda self, app_record, message, query, user_id, invoke_from: "reply",
)
response = runner.query_app_annotations_to_reply(
app_record=SimpleNamespace(),
message=SimpleNamespace(),
query="hello",
user_id="user",
invoke_from=InvokeFrom.WEB_APP,
)
assert response == "reply"
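These tests consistently stub collaborators with `types.SimpleNamespace` carrying `unittest.mock.MagicMock` attributes: attribute access works without defining throwaway classes, and every call is recorded. A minimal sketch of the pattern (the `save` function is hypothetical):

```python
from types import SimpleNamespace
from unittest.mock import MagicMock

# A fake DB session: real attribute names, recorded calls.
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock())


def save(session, record):
    session.add(record)
    session.commit()


save(db_session, {"id": 1})
db_session.add.assert_called_once_with({"id": 1})
db_session.commit.assert_called_once()
```

Because `SimpleNamespace` has no spec, typos in attribute names fail loudly at the call site rather than silently auto-creating mocks, which is one reason the tests above prefer it over a bare `MagicMock()`.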

View File

@@ -1,7 +0,0 @@
from core.app.apps.exc import GenerateTaskStoppedError
class TestAppsExceptions:
def test_generate_task_stopped_error(self):
err = GenerateTaskStoppedError("stopped")
assert str(err) == "stopped"


@@ -13,11 +13,9 @@ from core.app.app_config.entities import (
PromptTemplateEntity,
)
from core.app.apps import message_based_app_generator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from models.model import AppMode, Conversation, Message
from services.errors.app_model_config import AppModelConfigBrokenError
class DummyModelConf:
@@ -127,55 +125,3 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity():
assert entity.conversation_id == "generated-conversation-id"
assert entity.is_new_conversation is True
assert conversation.id == "generated-conversation-id"
class TestMessageBasedAppGeneratorExtras:
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
generator = MessageBasedAppGenerator()
class _Pipeline:
def __init__(self, **kwargs) -> None:
_ = kwargs
def process(self):
raise ValueError("I/O operation on closed file.")
monkeypatch.setattr(
"core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline",
_Pipeline,
)
with pytest.raises(GenerateTaskStoppedError):
generator._handle_response(
application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)),
queue_manager=SimpleNamespace(),
conversation=SimpleNamespace(id="conv"),
message=SimpleNamespace(id="msg"),
user=SimpleNamespace(),
stream=False,
)
def test_get_app_model_config_requires_valid_config(self, monkeypatch):
generator = MessageBasedAppGenerator()
app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None)
with pytest.raises(AppModelConfigBrokenError):
generator._get_app_model_config(app_model, conversation=None)
conversation = SimpleNamespace(app_model_config_id="missing-id")
monkeypatch.setattr(
message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None))
)
with pytest.raises(AppModelConfigBrokenError):
generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation)
def test_get_conversation_introduction_handles_missing_inputs(self):
app_config = _make_app_config(AppMode.CHAT)
app_config.additional_features.opening_statement = "Hello {{name}}"
entity = _make_chat_generate_entity(app_config)
entity.inputs = {}
generator = MessageBasedAppGenerator()
assert generator._get_conversation_introduction(entity) == "Hello {name}"
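The final assertion ("Hello {{name}}" with empty inputs becoming "Hello {name}") is consistent with plain `str.format` semantics, where doubled braces are escapes that collapse to literal single braces instead of raising on a missing key. Assuming the implementation formats the opening statement this way:

```python
template = "Hello {{name}}"
# str.format treats {{ and }} as escaped braces, so with no inputs the
# placeholder collapses to "{name}" rather than raising KeyError.
assert template.format() == "Hello {name}"
```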


@@ -1,65 +0,0 @@
from unittest.mock import Mock, patch
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent
class TestMessageBasedAppQueueManager:
def test_publish_stops_on_terminal_events(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
manager = MessageBasedAppQueueManager(
task_id="t1",
user_id="u1",
invoke_from=InvokeFrom.SERVICE_API,
conversation_id="c1",
app_mode="chat",
message_id="m1",
)
manager.stop_listen = Mock()
manager._is_stopped = Mock(return_value=False)
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock())
manager.stop_listen.assert_called_once()
def test_publish_raises_when_stopped(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
manager = MessageBasedAppQueueManager(
task_id="t1",
user_id="u1",
invoke_from=InvokeFrom.SERVICE_API,
conversation_id="c1",
app_mode="chat",
message_id="m1",
)
manager._is_stopped = Mock(return_value=True)
with pytest.raises(GenerateTaskStoppedError):
manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER)
def test_publish_enqueues_message_end(self):
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
mock_redis.setex.return_value = True
manager = MessageBasedAppQueueManager(
task_id="t1",
user_id="u1",
invoke_from=InvokeFrom.SERVICE_API,
conversation_id="c1",
app_mode="chat",
message_id="m1",
)
manager._is_stopped = Mock(return_value=False)
manager.stop_listen = Mock()
manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE)
assert manager._q.qsize() == 1
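The three tests above pin down one contract: terminal events stop listening, publishing after a stop raises, and ordinary events are enqueued. A minimal stop-aware publisher illustrating that contract (names hypothetical; `StoppedError` stands in for `GenerateTaskStoppedError`):

```python
import queue


class StoppedError(RuntimeError):
    pass


class MiniQueueManager:
    def __init__(self):
        self._q = queue.Queue()
        self._stopped = False

    def publish(self, event, terminal=False):
        if self._stopped:
            # Mirrors GenerateTaskStoppedError in the tests above.
            raise StoppedError("generation already stopped")
        self._q.put(event)
        if terminal:
            self._stopped = True


mgr = MiniQueueManager()
mgr.publish("chunk")
mgr.publish("message_end", terminal=True)
assert mgr._q.qsize() == 2
try:
    mgr.publish("late")
except StoppedError:
    pass
else:
    raise AssertionError("expected StoppedError")
```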


@@ -1,29 +0,0 @@
from unittest.mock import Mock, patch
from core.app.apps.message_generator import MessageGenerator
from models.model import AppMode
class TestMessageGenerator:
def test_get_response_topic(self):
channel = Mock()
channel.topic.return_value = "topic"
with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel):
topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1")
assert topic == "topic"
expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1")
channel.topic.assert_called_once_with(expected_key)
def test_retrieve_events_passes_arguments(self):
with (
patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"),
patch(
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
) as mock_stream,
):
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
assert events == [{"event": "ping"}]
mock_stream.assert_called_once()


@@ -6,7 +6,6 @@ import queue
import pytest
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events
from core.app.entities.task_entities import StreamEvent
from models.model import AppMode
@@ -79,30 +78,3 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
with pytest.raises(StopIteration):
next(generator)
def test_normalize_terminal_events_defaults():
assert _normalize_terminal_events(None) == {
StreamEvent.WORKFLOW_FINISHED.value,
StreamEvent.WORKFLOW_PAUSED.value,
}
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
topic = FakeTopic()
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
def fake_time():
return times.pop(0)
monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time)
generator = stream_topic_events(
topic=topic,
idle_timeout=10.0,
ping_interval=1.0,
)
assert next(generator) == StreamEvent.PING.value
# the next receive returns None, so the ping interval elapses and a second ping is emitted
assert next(generator) == StreamEvent.PING.value
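The timing test above swaps `time.time` for a function that pops from a scripted list, making idle/ping timing fully deterministic. The same fake clock works outside pytest; this variant repeats the final tick instead of raising `IndexError` when the script runs out (unlike the `pop(0)` version above):

```python
def make_fake_time(times):
    """Return a clock that replays `times` in order, repeating the last value."""
    state = {"i": 0}

    def fake_time():
        i = min(state["i"], len(times) - 1)
        state["i"] += 1
        return times[i]

    return fake_time


clock = make_fake_time([1000.0, 1000.0, 1001.0])
assert clock() == 1000.0
assert clock() == 1000.0
assert clock() == 1001.0
assert clock() == 1001.0  # exhausted scripts repeat the final tick
```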


@@ -1,261 +0,0 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
import pytest
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueLoopCompletedEvent,
QueueTextChunkEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeType
from dify_graph.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
)
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
class TestWorkflowBasedAppRunner:
def test_resolve_user_from(self):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT
assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT
assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER
def test_init_graph_validates_graph_structure(self):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable.default()),
start_at=0.0,
)
with pytest.raises(ValueError, match="nodes or edges not found"):
runner._init_graph(
graph_config={},
graph_runtime_state=runtime_state,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
with pytest.raises(ValueError, match="nodes in workflow graph must be a list"):
runner._init_graph(
graph_config={"nodes": {}, "edges": []},
graph_runtime_state=runtime_state,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
with pytest.raises(ValueError, match="edges in workflow graph must be a list"):
runner._init_graph(
graph_config={"nodes": [], "edges": {}},
graph_runtime_state=runtime_state,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
def test_prepare_single_node_execution_requires_run(self):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
workflow = SimpleNamespace(environment_variables=[], graph_dict={})
with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"):
runner._prepare_single_node_execution(workflow, None, None)
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable.default()),
start_at=0.0,
)
graph_config = {
"nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}],
"edges": [],
}
workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config)
monkeypatch.setattr(
"core.app.apps.workflow_app_runner.Graph.init",
lambda **kwargs: SimpleNamespace(),
)
class _NodeCls:
@staticmethod
def extract_variable_selector_to_variable_mapping(graph_config, config):
return {}
from core.app.apps import workflow_app_runner
monkeypatch.setitem(
workflow_app_runner.NODE_TYPE_CLASSES_MAPPING,
NodeType.START,
{"1": _NodeCls},
)
monkeypatch.setattr(
"core.app.apps.workflow_app_runner.load_into_variable_pool",
lambda **kwargs: None,
)
monkeypatch.setattr(
"core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool",
lambda **kwargs: None,
)
graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id="node-1",
user_inputs={},
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
assert graph is not None
assert variable_pool is graph_runtime_state.variable_pool
def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
published: list[object] = []
class _QueueManager:
def publish(self, event, publish_from):
published.append((event, publish_from))
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable.default()),
start_at=0.0,
)
graph_runtime_state.register_paused_node("node-1")
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
emails: list[dict] = []
class _Dispatch:
def apply_async(self, *, kwargs, queue):
emails.append({"kwargs": kwargs, "queue": queue})
monkeypatch.setattr(
"core.app.apps.workflow_app_runner.dispatch_human_input_email_task",
_Dispatch(),
)
reason = HumanInputRequired(
form_id="form",
form_content="content",
node_id="node-1",
node_title="Node",
)
runner._handle_event(workflow_entry, GraphRunStartedEvent())
runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True}))
runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={}))
assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published)
assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published)
paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent))
assert paused_event.paused_nodes == ["node-1"]
assert emails
def test_handle_node_events_publishes_queue_events(self):
published: list[object] = []
class _QueueManager:
def publish(self, event, publish_from):
published.append(event)
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable.default()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
runner._handle_event(
workflow_entry,
NodeRunStartedEvent(
id="exec",
node_id="node",
node_type=NodeType.START,
node_title="Start",
start_at=datetime.utcnow(),
),
)
runner._handle_event(
workflow_entry,
NodeRunStreamChunkEvent(
id="exec",
node_id="node",
node_type=NodeType.START,
selector=["node", "text"],
chunk="hi",
is_final=False,
),
)
runner._handle_event(
workflow_entry,
NodeRunAgentLogEvent(
id="exec",
node_id="node",
node_type=NodeType.START,
message_id="msg",
label="label",
node_execution_id="exec",
parent_id=None,
error=None,
status="done",
data={},
metadata={},
),
)
runner._handle_event(
workflow_entry,
NodeRunIterationSucceededEvent(
id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="Iter",
start_at=datetime.utcnow(),
inputs={},
outputs={"ok": True},
metadata={},
steps=1,
),
)
runner._handle_event(
workflow_entry,
NodeRunLoopFailedEvent(
id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="Loop",
start_at=datetime.utcnow(),
inputs={},
outputs={},
metadata={},
steps=1,
error="boom",
),
)
assert any(isinstance(event, QueueTextChunkEvent) for event in published)
assert any(isinstance(event, QueueAgentLogEvent) for event in published)
assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)


@@ -1,61 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from models.model import AppMode
class TestWorkflowAppConfigManager:
def test_get_app_config(self):
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
workflow = SimpleNamespace(id="wf-1", features_dict={})
with (
patch(
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
return_value=None,
),
patch(
"core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert",
return_value=[],
),
):
app_config = WorkflowAppConfigManager.get_app_config(app_model, workflow)
assert app_config.workflow_id == "wf-1"
assert app_config.app_mode == AppMode.WORKFLOW
def test_config_validate_filters_keys(self):
def _add_key(key, value):
def _inner(*args, **kwargs):
# Support both positional and keyword arguments for config
if "config" in kwargs:
config = kwargs["config"]
elif len(args) > 0:
config = args[0]
else:
config = {}
config[key] = value
return config, [key]
return _inner
with (
patch(
"core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
side_effect=_add_key("file_upload", 1),
),
patch(
"core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
side_effect=_add_key("text_to_speech", 2),
),
patch(
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
side_effect=_add_key("sensitive_word_avoidance", 3),
),
):
filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={})
assert filtered["file_upload"] == 1
assert filtered["text_to_speech"] == 2
assert filtered["sensitive_word_avoidance"] == 3
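`_add_key` above is a closure factory: each patched validator gets its own `side_effect` that mutates the shared config dict and reports which key it claimed. Reduced to its essentials (illustrative names):

```python
def add_key(key, value):
    def _inner(config):
        config[key] = value
        return config, [key]
    return _inner


config = {}
validators = [add_key("file_upload", 1), add_key("text_to_speech", 2)]
claimed = []
for validate in validators:
    config, keys = validate(config)
    claimed.extend(keys)

assert config == {"file_upload": 1, "text_to_speech": 2}
assert claimed == ["file_upload", "text_to_speech"]
```

Because each `_inner` closes over its own `key`/`value` pair, the factory avoids the classic late-binding pitfall of building the lambdas in a loop.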


@@ -1,188 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TestWorkflowAppGeneratorValidation:
def test_should_prepare_user_inputs(self):
generator = WorkflowAppGenerator()
assert generator._should_prepare_user_inputs({}) is True
assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False
def test_single_iteration_generate_validates_args(self):
generator = WorkflowAppGenerator()
with pytest.raises(ValueError, match="node_id is required"):
generator.single_iteration_generate(
app_model=SimpleNamespace(),
workflow=SimpleNamespace(),
node_id="",
user=SimpleNamespace(),
args={"inputs": {}},
streaming=False,
)
with pytest.raises(ValueError, match="inputs is required"):
generator.single_iteration_generate(
app_model=SimpleNamespace(),
workflow=SimpleNamespace(),
node_id="node",
user=SimpleNamespace(),
args={},
streaming=False,
)
def test_single_loop_generate_validates_args(self):
generator = WorkflowAppGenerator()
with pytest.raises(ValueError, match="node_id is required"):
generator.single_loop_generate(
app_model=SimpleNamespace(),
workflow=SimpleNamespace(),
node_id="",
user=SimpleNamespace(),
args=SimpleNamespace(inputs={}),
streaming=False,
)
with pytest.raises(ValueError, match="inputs is required"):
generator.single_loop_generate(
app_model=SimpleNamespace(),
workflow=SimpleNamespace(),
node_id="node",
user=SimpleNamespace(),
args=SimpleNamespace(inputs=None),
streaming=False,
)
class TestWorkflowAppGeneratorHandleResponse:
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
generator = WorkflowAppGenerator()
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
extras={},
trace_manager=None,
workflow_execution_id="run-id",
call_depth=0,
)
class _Pipeline:
def __init__(self, **kwargs) -> None:
_ = kwargs
def process(self):
raise ValueError("I/O operation on closed file.")
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline",
_Pipeline,
)
with pytest.raises(GenerateTaskStoppedError):
generator._handle_response(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(),
queue_manager=SimpleNamespace(),
user=SimpleNamespace(),
draft_var_saver_factory=lambda **kwargs: None,
stream=False,
)
class TestWorkflowAppGeneratorGenerate:
def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch):
generator = WorkflowAppGenerator()
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config",
lambda app_model, workflow: app_config,
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.FileUploadConfigManager.convert",
lambda features_dict, is_vision=False: None,
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.file_factory.build_from_mappings",
lambda **kwargs: [],
)
DummyTraceQueueManager = type(
"_DummyTraceQueueManager",
(TraceQueueManager,),
{
"__init__": lambda self, app_id=None, user_id=None: (
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
)
},
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.TraceQueueManager",
DummyTraceQueueManager,
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
lambda **kwargs: SimpleNamespace(),
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
lambda **kwargs: SimpleNamespace(),
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.db",
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
)
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.sessionmaker",
lambda **kwargs: SimpleNamespace(),
)
prepare_inputs = pytest.fail  # fail the test outright if _prepare_user_inputs is ever invoked
monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs())
monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True})
result = generator.generate(
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
workflow=SimpleNamespace(features_dict={}),
user=SimpleNamespace(id="user", session_id="session"),
args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True},
invoke_from=InvokeFrom.WEB_APP,
streaming=False,
call_depth=0,
)
assert result == {"ok": True}
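`DummyTraceQueueManager` above is built at runtime with `type()`, and its one-line `__init__` chains two `setattr` calls with `or`: each `setattr` returns `None`, so both run in order and the lambda still returns `None`, as `__init__` must. The trick in isolation (names illustrative):

```python
class Base:
    pass


Dummy = type(
    "Dummy",
    (Base,),
    {
        # setattr returns None, so `a or b` evaluates both calls in order
        # and the lambda returns None, which is what __init__ requires.
        "__init__": lambda self, app_id=None, user_id=None: (
            setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
        )
    },
)

obj = Dummy(app_id="app", user_id="user")
assert isinstance(obj, Base)
assert obj.app_id == "app" and obj.user_id == "user"
```

This is useful in tests when a real base class enforces `isinstance` checks but its real `__init__` has side effects you want to skip.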


@@ -1,33 +0,0 @@
from __future__ import annotations
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent
class TestWorkflowAppQueueManager:
def test_publish_stop_events_trigger_stop(self):
manager = WorkflowAppQueueManager(
task_id="task",
user_id="user",
invoke_from=InvokeFrom.DEBUGGER,
app_mode="workflow",
)
manager._is_stopped = lambda: True
with pytest.raises(GenerateTaskStoppedError):
manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER)
def test_publish_non_stop_event_does_not_raise(self):
manager = WorkflowAppQueueManager(
task_id="task",
user_id="user",
invoke_from=InvokeFrom.DEBUGGER,
app_mode="workflow",
)
manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)


@@ -1,9 +0,0 @@
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
class TestWorkflowErrors:
def test_workflow_paused_in_blocking_mode_error_attributes(self):
err = WorkflowPausedInBlockingModeError()
assert err.error_code == "workflow_paused_in_blocking_mode"
assert err.code == 400
assert "blocking response mode" in err.description


@@ -1,133 +0,0 @@
from collections.abc import Generator
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.entities.task_entities import (
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
class TestWorkflowGenerateResponseConverter:
def test_blocking_full_response(self):
blocking = WorkflowAppBlockingResponse(
task_id="t1",
workflow_run_id="r1",
data=WorkflowAppBlockingResponse.Data(
id="exec-1",
workflow_id="wf-1",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs={"ok": True},
error=None,
elapsed_time=1.2,
total_tokens=10,
total_steps=2,
created_at=1,
finished_at=2,
),
)
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
assert response["workflow_run_id"] == "r1"
def test_stream_simple_response_node_events(self):
node_start = NodeStartStreamResponse(
task_id="t1",
workflow_run_id="r1",
data=NodeStartStreamResponse.Data(
id="e1",
node_id="n1",
node_type="answer",
title="Answer",
index=1,
created_at=1,
),
)
node_finish = NodeFinishStreamResponse(
task_id="t1",
workflow_run_id="r1",
data=NodeFinishStreamResponse.Data(
id="e1",
node_id="n1",
node_type="answer",
title="Answer",
index=1,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
elapsed_time=0.1,
created_at=1,
finished_at=2,
),
)
def stream() -> Generator[WorkflowAppStreamResponse, None, None]:
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1"))
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start)
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish)
yield WorkflowAppStreamResponse(
workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom"))
)
converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
assert converted[0] == "ping"
assert converted[1]["event"] == "node_started"
assert converted[2]["event"] == "node_finished"
assert converted[3]["event"] == "error"
def test_convert_stream_simple_response_handles_ping_and_nodes(self):
def _gen():
yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task"))
yield WorkflowAppStreamResponse(
workflow_run_id="run",
stream_response=NodeStartStreamResponse(
task_id="task",
workflow_run_id="run",
data=NodeStartStreamResponse.Data(
id="node-exec",
node_id="node",
node_type="start",
title="Start",
index=1,
created_at=1,
),
),
)
yield WorkflowAppStreamResponse(
workflow_run_id="run",
stream_response=NodeFinishStreamResponse(
task_id="task",
workflow_run_id="run",
data=NodeFinishStreamResponse.Data(
id="node-exec",
node_id="node",
node_type="start",
title="Start",
index=1,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
created_at=1,
finished_at=2,
elapsed_time=1.0,
error=None,
),
),
)
chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen()))
assert chunks[0] == "ping"
assert chunks[1]["event"] == "node_started"
assert chunks[2]["event"] == "node_finished"
def test_convert_stream_full_response_handles_error(self):
def _gen():
yield WorkflowAppStreamResponse(
workflow_run_id="run",
stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")),
)
chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen()))
assert chunks[0]["event"] == "error"


@@ -1,868 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from dify_graph.enums import NodeType, WorkflowExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from models.enums import CreatorUserRole
from models.model import AppMode, EndUser
def _make_pipeline():
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
trace_manager=None,
workflow_execution_id="run-id",
extras={},
call_depth=0,
)
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
user = SimpleNamespace(id="user", session_id="session")
pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
user=user,
stream=False,
draft_var_saver_factory=lambda **kwargs: None,
)
return pipeline
class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_handles_pause(self):
pipeline = _make_pipeline()
def _gen():
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert response.data.status == WorkflowExecutionStatus.PAUSED
def test_to_blocking_response_handles_finish(self):
pipeline = _make_pipeline()
def _gen():
yield WorkflowFinishStreamResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowFinishStreamResponse.Data(
id="run",
workflow_id="workflow-id",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs={"ok": True},
error=None,
elapsed_time=1.0,
total_tokens=5,
total_steps=2,
created_at=1,
finished_at=2,
),
)
response = pipeline._to_blocking_response(_gen())
assert response.data.outputs == {"ok": True}
def test_listen_audio_msg_returns_audio_stream(self):
pipeline = _make_pipeline()
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
assert isinstance(response, MessageAudioStreamResponse)
def test_handle_ping_event(self):
pipeline = _make_pipeline()
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
assert isinstance(responses[0], PingStreamResponse)
def test_handle_error_event(self):
pipeline = _make_pipeline()
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
assert isinstance(responses[0], ValueError)
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
@contextmanager
def _fake_session():
yield SimpleNamespace()
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None)
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
assert pipeline._workflow_execution_id == "run-id"
assert responses == ["started"]
def test_handle_node_succeeded_event_saves_output(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
pipeline._save_output_for_event = lambda event, node_execution_id: None
pipeline._workflow_execution_id = "run-id"
event = QueueNodeSucceededEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.START,
start_at=datetime.utcnow(),
inputs={},
outputs={},
process_data={},
)
responses = list(pipeline._handle_node_succeeded_event(event))
assert responses == ["done"]
def test_handle_workflow_failed_event_yields_error(self):
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
responses = list(
pipeline._handle_workflow_failed_and_stop_events(QueueWorkflowFailedEvent(error="fail", exceptions_count=1))
)
assert responses[0] == "finish"
def test_handle_text_chunk_event_publishes_tts(self):
pipeline = _make_pipeline()
published: list[object] = []
class _Publisher:
def publish(self, message):
published.append(message)
event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"])
queue_message = SimpleNamespace(event=event)
responses = list(
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
)
assert responses[0].data.text == "hi"
assert published == [queue_message]
def test_dispatch_event_handles_node_failed(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
event = QueueNodeFailedEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.START,
start_at=datetime.utcnow(),
inputs={},
outputs={},
process_data={},
error="err",
)
assert list(pipeline._dispatch_event(event)) == ["done"]
def test_handle_stop_event_yields_finish(self):
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
responses = list(
pipeline._handle_workflow_failed_and_stop_events(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
)
)
assert responses == ["finish"]
def test_save_workflow_app_log_created_from(self):
pipeline = _make_pipeline()
pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API
pipeline._user_id = "user"
added: list[object] = []
class _Session:
def add(self, item):
added.append(item)
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
assert added
def test_iteration_loop_and_human_input_handlers(self):
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter"
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next"
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done"
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop"
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log"
iter_start = QueueIterationStartEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
iter_next = QueueIterationNextEvent(
index=1,
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
node_run_index=1,
)
iter_done = QueueIterationCompletedEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
loop_start = QueueLoopStartEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
loop_next = QueueLoopNextEvent(
index=1,
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
node_run_index=1,
)
loop_done = QueueLoopCompletedEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="LLM",
start_at=datetime.utcnow(),
node_run_index=1,
)
filled_event = QueueHumanInputFormFilledEvent(
node_execution_id="exec",
node_id="node",
node_type=NodeType.LLM,
node_title="title",
rendered_content="content",
action_id="action",
action_text="action",
)
timeout_event = QueueHumanInputFormTimeoutEvent(
node_id="node",
node_type=NodeType.LLM,
node_title="title",
expiration_time=datetime.utcnow(),
)
agent_event = QueueAgentLogEvent(
id="log",
label="label",
node_execution_id="exec",
parent_id=None,
error=None,
status="done",
data={},
metadata={},
node_id="node",
)
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"]
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"]
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"]
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"]
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"]
def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._workflow_features_dict = {
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
}
pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
class _Publisher:
def __init__(self, *args, **kwargs):
self.calls = 0
def check_and_get_audio(self):
self.calls += 1
if self.calls == 1:
return AudioTrunk(status="stream", audio="data")
if self.calls == 2:
return None
return AudioTrunk(status="finish", audio="")
def publish(self, message):
return None
monkeypatch.setattr(
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
_Publisher,
)
responses = list(pipeline._wrapper_process_stream_response())
assert any(isinstance(item, MessageAudioStreamResponse) for item in responses)
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
def test_init_with_end_user_sets_role_and_system_user(self):
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="end-user-id",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
trace_manager=None,
workflow_execution_id="run-id",
extras={},
call_depth=0,
)
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None)
end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id")
end_user.id = "end-user-id"
pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=end_user,
stream=False,
draft_var_saver_factory=lambda **kwargs: None,
)
assert pipeline._created_by_role == CreatorUserRole.END_USER
assert pipeline._workflow_system_variables.user_id == "session-id"
def test_process_returns_stream_and_blocking_variants(self):
pipeline = _make_pipeline()
pipeline._base_task_pipeline.stream = True
pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
stream_response = list(pipeline.process())
assert len(stream_response) == 1
assert stream_response[0].workflow_run_id is None
pipeline._base_task_pipeline.stream = False
pipeline._wrapper_process_stream_response = lambda **kwargs: iter(
[
WorkflowFinishStreamResponse(
task_id="task",
workflow_run_id="run-id",
data=WorkflowFinishStreamResponse.Data(
id="run-id",
workflow_id="workflow-id",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs={},
error=None,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
created_at=1,
finished_at=2,
),
)
]
)
blocking_response = pipeline.process()
assert blocking_response.workflow_run_id == "run-id"
def test_to_blocking_response_handles_error_and_unexpected_end(self):
pipeline = _make_pipeline()
def _error_gen():
yield ErrorStreamResponse(task_id="task", err=ValueError("boom"))
with pytest.raises(ValueError, match="boom"):
pipeline._to_blocking_response(_error_gen())
def _unexpected_gen():
yield PingStreamResponse(task_id="task")
with pytest.raises(ValueError, match="queue listening stopped unexpectedly"):
pipeline._to_blocking_response(_unexpected_gen())
def test_to_stream_response_tracks_workflow_run_id(self):
pipeline = _make_pipeline()
def _gen():
yield WorkflowStartStreamResponse(
task_id="task",
workflow_run_id="run-id",
data=WorkflowStartStreamResponse.Data(
id="run-id",
workflow_id="workflow-id",
inputs={},
created_at=1,
),
)
yield PingStreamResponse(task_id="task")
stream_responses = list(pipeline._to_stream_response(_gen()))
assert stream_responses[0].workflow_run_id == "run-id"
assert stream_responses[1].workflow_run_id == "run-id"
def test_listen_audio_msg_returns_none_without_publisher(self):
pipeline = _make_pipeline()
assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None
def test_wrapper_process_stream_response_without_tts(self):
pipeline = _make_pipeline()
pipeline._workflow_features_dict = {}
pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
responses = list(pipeline._wrapper_process_stream_response())
assert responses == [PingStreamResponse(task_id="task")]
def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._workflow_features_dict = {
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
}
pipeline._process_stream_response = lambda **kwargs: iter([])
sleep_spy = []
class _Publisher:
def __init__(self, *args, **kwargs):
self.calls = 0
def check_and_get_audio(self):
self.calls += 1
if self.calls == 1:
return None
return AudioTrunk(status="finish", audio="")
def publish(self, message):
_ = message
time_values = iter([0.0, 0.0, 0.2])
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values))
monkeypatch.setattr(
"core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True)
)
monkeypatch.setattr(
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
_Publisher,
)
responses = list(pipeline._wrapper_process_stream_response())
assert sleep_spy
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._workflow_features_dict = {
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
}
pipeline._process_stream_response = lambda **kwargs: iter([])
class _Publisher:
def __init__(self, *args, **kwargs):
self.called = False
def check_and_get_audio(self):
if not self.called:
self.called = True
raise RuntimeError("tts failure")
return AudioTrunk(status="finish", audio="")
def publish(self, message):
_ = message
logger_exception = []
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0)
monkeypatch.setattr(
"core.app.apps.workflow.generate_task_pipeline.logger.exception",
lambda *args, **kwargs: logger_exception.append((args, kwargs)),
)
monkeypatch.setattr(
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
_Publisher,
)
responses = list(pipeline._wrapper_process_stream_response())
assert logger_exception
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
def test_database_session_rolls_back_on_error(self, monkeypatch):
pipeline = _make_pipeline()
calls = {"commit": 0, "rollback": 0}
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def commit(self):
calls["commit"] += 1
def rollback(self):
calls["rollback"] += 1
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
with pytest.raises(RuntimeError, match="db error"):
with pipeline._database_session():
raise RuntimeError("db error")
assert calls["commit"] == 0
assert calls["rollback"] == 1
def test_node_retry_and_started_handlers_cover_none_and_value(self):
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
retry_event = QueueNodeRetryEvent(
node_execution_id="exec",
node_id="node",
node_title="title",
node_type=NodeType.LLM,
node_run_index=1,
start_at=datetime.utcnow(),
provider_type="provider",
provider_id="provider-id",
error="error",
retry_index=1,
)
started_event = QueueNodeStartedEvent(
node_execution_id="exec",
node_id="node",
node_title="title",
node_type=NodeType.LLM,
node_run_index=1,
start_at=datetime.utcnow(),
provider_type="provider",
provider_id="provider-id",
)
pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None
assert list(pipeline._handle_node_retry_event(retry_event)) == []
pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry"
assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"]
pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None
assert list(pipeline._handle_node_started_event(started_event)) == []
pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: "started"
assert list(pipeline._handle_node_started_event(started_event)) == ["started"]
def test_handle_node_exception_event_saves_output(self):
pipeline = _make_pipeline()
saved_ids: list[str] = []
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
pipeline._save_output_for_event = lambda event, node_execution_id: saved_ids.append(node_execution_id)
event = QueueNodeExceptionEvent(
node_execution_id="exec-id",
node_id="node",
node_type=NodeType.START,
start_at=datetime.utcnow(),
inputs={},
outputs={},
process_data={},
error="boom",
)
responses = list(pipeline._handle_node_failed_events(event))
assert responses == ["failed"]
assert saved_ids == ["exec-id"]
def test_success_partial_and_pause_handlers(self):
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"]
assert list(
pipeline._handle_workflow_partial_success_event(
QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={})
)
) == ["finish"]
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [
"pause-a",
"pause-b",
]
pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"])
assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"]
def test_text_chunk_handler_returns_empty_when_text_missing(self):
pipeline = _make_pipeline()
event = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None)
assert list(pipeline._handle_text_chunk_event(event)) == []
def test_dispatch_event_direct_failed_and_unhandled_paths(self):
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"]
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"])
assert list(pipeline._dispatch_event(QueueWorkflowFailedEvent(error="failed", exceptions_count=1))) == [
"workflow-failed"
]
assert list(pipeline._dispatch_event(SimpleNamespace())) == []
def test_process_stream_response_main_match_paths_and_cleanup(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
start_at=0.0,
)
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
[
SimpleNamespace(event=QueueWorkflowStartedEvent()),
SimpleNamespace(event=QueueTextChunkEvent(text="hello")),
SimpleNamespace(event=QueuePingEvent()),
SimpleNamespace(event=QueueErrorEvent(error="e")),
]
)
pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"])
pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"])
pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"])
pipeline._handle_error_event = lambda event, **kwargs: iter(["error"])
publisher_calls: list[object] = []
class _Publisher:
def publish(self, message):
publisher_calls.append(message)
responses = list(pipeline._process_stream_response(tts_publisher=_Publisher()))
assert responses == ["started", "text", "dispatched", "error"]
assert publisher_calls == [None]
def test_process_stream_response_break_paths(self):
pipeline = _make_pipeline()
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
[SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))]
)
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"])
assert list(pipeline._process_stream_response()) == ["failed"]
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
[SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))]
)
pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"])
assert list(pipeline._process_stream_response()) == ["paused"]
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
[SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))]
)
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"])
assert list(pipeline._process_stream_response()) == ["stopped"]
def test_save_workflow_app_log_covers_invoke_from_variants(self):
pipeline = _make_pipeline()
pipeline._user_id = "user-id"
added: list[object] = []
class _Session:
def add(self, item):
added.append(item)
pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
assert added[-1].created_from == "installed-app"
pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
assert added[-1].created_from == "web-app"
count_before = len(added)
pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
assert len(added) == count_before
pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None)
assert len(added) == count_before
def test_save_output_for_event_writes_draft_variables(self, monkeypatch):
pipeline = _make_pipeline()
saver_calls: list[tuple[object, object]] = []
captured_factory_args: dict[str, object] = {}
class _Saver:
def save(self, process_data, outputs):
saver_calls.append((process_data, outputs))
def _factory(**kwargs):
captured_factory_args.update(kwargs)
return _Saver()
class _Begin:
def __enter__(self):
return None
def __exit__(self, exc_type, exc, tb):
return False
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def begin(self):
return _Begin()
pipeline._draft_var_saver_factory = _factory
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
event = QueueNodeSucceededEvent(
node_execution_id="exec-id",
node_id="node-id",
node_type=NodeType.START,
in_loop_id="loop-id",
start_at=datetime.utcnow(),
process_data={"k": "v"},
outputs={"out": 1},
)
pipeline._save_output_for_event(event=event, node_execution_id="exec-id")
assert captured_factory_args["node_execution_id"] == "exec-id"
assert captured_factory_args["enclosing_node_id"] == "loop-id"
assert saver_calls == [({"k": "v"}, {"out": 1})]


@@ -1,390 +0,0 @@
import base64
import queue
from unittest.mock import MagicMock
import pytest
from core.base.tts.app_generator_tts_publisher import (
AppGeneratorTTSPublisher,
AudioTrunk,
_invoice_tts,
_process_future,
)
# =========================
# Fixtures
# =========================
@pytest.fixture
def mock_model_instance(mocker):
model = mocker.MagicMock()
model.invoke_tts.return_value = [b"audio1", b"audio2"]
model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}]
return model
@pytest.fixture
def mock_model_manager(mocker, mock_model_instance):
manager = mocker.MagicMock()
manager.get_default_model_instance.return_value = mock_model_instance
mocker.patch(
"core.base.tts.app_generator_tts_publisher.ModelManager",
return_value=manager,
)
return manager
@pytest.fixture(autouse=True)
def patch_threads(mocker):
"""Prevent real threads from starting during tests"""
mocker.patch("threading.Thread.start", return_value=None)
# =========================
# AudioTrunk Tests
# =========================
class TestAudioTrunk:
def test_audio_trunk_initialization(self):
trunk = AudioTrunk("responding", b"data")
assert trunk.status == "responding"
assert trunk.audio == b"data"
# =========================
# _invoice_tts Tests
# =========================
class TestInvoiceTTS:
@pytest.mark.parametrize(
"text",
[None, "", " "],
)
def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
assert result is None
mock_model_instance.invoke_tts.assert_not_called()
def test_invoice_tts_valid_text(self, mock_model_instance):
result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
mock_model_instance.invoke_tts.assert_called_once_with(
content_text="hello",
user="responding_tts",
tenant_id="tenant",
voice="voice1",
)
assert result == [b"audio1", b"audio2"]
# =========================
# _process_future Tests
# =========================
class TestProcessFuture:
def test_process_future_normal_flow(self):
future_queue = queue.Queue()
audio_queue = queue.Queue()
future = MagicMock()
future.result.return_value = [b"abc"]
future_queue.put(future)
future_queue.put(None)
_process_future(future_queue, audio_queue)
first = audio_queue.get()
assert first.status == "responding"
assert first.audio == base64.b64encode(b"abc")
finish = audio_queue.get()
assert finish.status == "finish"
def test_process_future_empty_result(self):
future_queue = queue.Queue()
audio_queue = queue.Queue()
future = MagicMock()
future.result.return_value = None
future_queue.put(future)
future_queue.put(None)
_process_future(future_queue, audio_queue)
finish = audio_queue.get()
assert finish.status == "finish"
def test_process_future_exception(self, mocker):
future_queue = queue.Queue()
audio_queue = queue.Queue()
future = MagicMock()
future.result.side_effect = Exception("error")
future_queue.put(future)
_process_future(future_queue, audio_queue)
finish = audio_queue.get()
assert finish.status == "finish"
# =========================
# AppGeneratorTTSPublisher Tests
# =========================
class TestAppGeneratorTTSPublisher:
def test_initialization_valid_voice(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
assert publisher.voice == "voice1"
assert publisher.max_sentence == 2
assert publisher.msg_text == ""
def test_initialization_invalid_voice_fallback(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice")
assert publisher.voice == "voice1"
def test_publish_puts_message_in_queue(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
message = MagicMock()
publisher.publish(message)
assert publisher._msg_queue.get() == message
def test_check_and_get_audio_no_audio(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
result = publisher.check_and_get_audio()
assert result is None
def test_check_and_get_audio_non_finish_event(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
trunk = AudioTrunk("responding", b"abc")
publisher._audio_queue.put(trunk)
result = publisher.check_and_get_audio()
assert result.status == "responding"
assert publisher._last_audio_event == trunk
def test_check_and_get_audio_finish_event(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
finish_trunk = AudioTrunk("finish", b"")
publisher._audio_queue.put(finish_trunk)
result = publisher.check_and_get_audio()
assert result.status == "finish"
publisher.executor.shutdown.assert_called_once()
def test_check_and_get_audio_cached_finish(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
publisher._last_audio_event = AudioTrunk("finish", b"")
result = publisher.check_and_get_audio()
assert result.status == "finish"
publisher.executor.shutdown.assert_called_once()
@pytest.mark.parametrize(
("text", "expected_sentences", "expected_remaining"),
[
("Hello world.", ["Hello world."], ""),
("Hello world! How are you?", ["Hello world!", " How are you?"], ""),
("No punctuation", [], "No punctuation"),
("", [], ""),
],
)
def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
sentences, remaining = publisher._extract_sentence(text)
assert sentences == expected_sentences
assert remaining == expected_remaining
def test_runtime_handles_none_message_with_buffer(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
publisher.msg_text = "Hello."
publisher._msg_queue.put(None)
publisher._runtime()
publisher.executor.submit.assert_called_once()
def test_runtime_handles_none_message_without_buffer(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
publisher.msg_text = " "
publisher._msg_queue.put(None)
publisher._runtime()
publisher.executor.submit.assert_not_called()
def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
# Force sentence extraction to hit threshold condition
mocker.patch.object(
publisher,
"_extract_sentence",
return_value=(["Hello world.", " Second sentence."], ""),
)
from core.app.entities.queue_entities import QueueTextChunkEvent
event = MagicMock()
event.event = MagicMock(spec=QueueTextChunkEvent)
event.event.text = "Hello world. Second sentence."
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.executor.submit.called
def test_runtime_handles_text_chunk_event(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueTextChunkEvent
event = MagicMock()
event.event = MagicMock(spec=QueueTextChunkEvent)
event.event.text = "Hello world."
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.executor.submit.called
    def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager):
        publisher = AppGeneratorTTSPublisher("tenant", "voice1")
        publisher.executor = MagicMock()
        from core.app.entities.queue_entities import QueueNodeSucceededEvent

        event = MagicMock()
        event.event = MagicMock(spec=QueueNodeSucceededEvent)
        event.event.outputs = {"output": "Hello world."}
        publisher._msg_queue.put(event)
        publisher._msg_queue.put(None)
        publisher._runtime()
        assert publisher.executor.submit.called

    def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager):
        publisher = AppGeneratorTTSPublisher("tenant", "voice1")
        publisher.executor = MagicMock()
        from core.app.entities.queue_entities import QueueNodeSucceededEvent

        event = MagicMock()
        event.event = MagicMock(spec=QueueNodeSucceededEvent)
        event.event.outputs = None
        publisher._msg_queue.put(event)
        publisher._msg_queue.put(None)
        publisher._runtime()
        publisher.executor.submit.assert_not_called()

    def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker):
        publisher = AppGeneratorTTSPublisher("tenant", "voice1")
        publisher.executor = MagicMock()
        from core.app.entities.queue_entities import QueueAgentMessageEvent
        from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
        from dify_graph.model_runtime.entities.message_entities import (
            AssistantPromptMessage,
            ImagePromptMessageContent,
            TextPromptMessageContent,
        )

        chunk = LLMResultChunk(
            model="model",
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(
                    content=[
                        TextPromptMessageContent(data="Hello "),
                        ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"),
                    ]
                ),
            ),
        )
        event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
        mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
        publisher._msg_queue.put(event)
        publisher._msg_queue.put(None)
        publisher._runtime()
        assert publisher.msg_text == "Hello "

    def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker):
        publisher = AppGeneratorTTSPublisher("tenant", "voice1")
        publisher.executor = MagicMock()
        from core.app.entities.queue_entities import QueueAgentMessageEvent
        from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
        from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage

        chunk = LLMResultChunk(
            model="model",
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=""),
            ),
        )
        event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
        mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
        publisher._msg_queue.put(event)
        publisher._msg_queue.put(None)
        publisher._runtime()
        assert publisher.msg_text == ""

    def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker):
        publisher = AppGeneratorTTSPublisher("tenant", "voice1")
        publisher.executor = MagicMock()
        from core.app.entities.queue_entities import QueueTextChunkEvent

        event = MagicMock()
        event.event = MagicMock(spec=QueueTextChunkEvent)
        event.event.text = "Hello world. Another sentence."
        mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None))
        publisher._msg_queue.put(event)
        publisher._msg_queue.put(None)
        publisher._runtime()
        assert publisher.msg_text == ""

    def test_runtime_exception_path(self, mock_model_manager):
        publisher = AppGeneratorTTSPublisher("tenant", "voice1")
        publisher._msg_queue = MagicMock()
        publisher._msg_queue.get.side_effect = Exception("error")
        publisher._runtime()
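The `_runtime` tests above all follow the same shape: enqueue events, enqueue a `None` sentinel, run the loop, and assert on the executor. A minimal sketch of that queue-draining pattern — with hypothetical names mirroring the publisher under test, not the actual Dify implementation:

```python
import queue
from concurrent.futures import ThreadPoolExecutor


class TTSPublisherSketch:
    """Drains a message queue until a None sentinel arrives (hedged sketch)."""

    def __init__(self):
        self._msg_queue = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.msg_text = ""

    def _runtime(self):
        while True:
            try:
                message = self._msg_queue.get()
            except Exception:
                # the exception-path test asserts the loop survives a failing get()
                break
            if message is None:
                break  # sentinel: stop consuming
            outputs = getattr(getattr(message, "event", None), "outputs", None)
            if outputs:
                # only non-empty node outputs are handed off for audio synthesis
                self.executor.submit(print, outputs)
```

The `None` sentinel is what lets each test terminate `_runtime` deterministically without spawning threads.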

View File

@@ -1,197 +0,0 @@
from unittest.mock import MagicMock

import pytest

import core.callback_handler.agent_tool_callback_handler as module

# -----------------------------
# Fixtures
# -----------------------------
@pytest.fixture
def enable_debug(mocker):
    mocker.patch.object(module.dify_config, "DEBUG", True)

@pytest.fixture
def disable_debug(mocker):
    mocker.patch.object(module.dify_config, "DEBUG", False)

@pytest.fixture
def mock_print(mocker):
    return mocker.patch("builtins.print")

@pytest.fixture
def handler():
    return module.DifyAgentCallbackHandler(color="blue")

# -----------------------------
# get_colored_text Tests
# -----------------------------
class TestGetColoredText:
    @pytest.mark.parametrize(
        ("color", "expected_code"),
        [
            ("blue", "36;1"),
            ("yellow", "33;1"),
            ("pink", "38;5;200"),
            ("green", "32;1"),
            ("red", "31;1"),
        ],
    )
    def test_get_colored_text_valid_colors(self, color, expected_code):
        text = "hello"
        result = module.get_colored_text(text, color)
        assert expected_code in result
        assert text in result
        assert result.endswith("\u001b[0m")

    def test_get_colored_text_invalid_color_raises(self):
        with pytest.raises(KeyError):
            module.get_colored_text("hello", "invalid")

    def test_get_colored_text_empty_string(self):
        result = module.get_colored_text("", "green")
        assert "\u001b[" in result
# -----------------------------
# print_text Tests
# -----------------------------
class TestPrintText:
    def test_print_text_without_color(self, mock_print):
        module.print_text("hello")
        mock_print.assert_called_once_with("hello", end="", file=None)

    def test_print_text_with_color(self, mocker, mock_print):
        mock_get_color = mocker.patch(
            "core.callback_handler.agent_tool_callback_handler.get_colored_text",
            return_value="colored_text",
        )
        module.print_text("hello", color="green")
        mock_get_color.assert_called_once_with("hello", "green")
        mock_print.assert_called_once_with("colored_text", end="", file=None)

    def test_print_text_with_file_flush(self, mocker):
        mock_file = MagicMock()
        mock_print = mocker.patch("builtins.print")
        module.print_text("hello", file=mock_file)
        mock_print.assert_called_once_with("hello", end="", file=mock_file)
        mock_file.flush.assert_called_once()

    def test_print_text_with_end_parameter(self, mock_print):
        module.print_text("hello", end="\n")
        mock_print.assert_called_once_with("hello", end="\n", file=None)
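These tests describe `print_text` as a thin wrapper over built-in `print`: optional coloring, `end=""` by default, and an explicit `flush()` only when a file object is passed. A self-contained sketch of that contract (signature assumed from the call sites; the color branch is illustrative only):

```python
from typing import Optional, TextIO


def print_text(text: str, color: Optional[str] = None, end: str = "",
               file: Optional[TextIO] = None) -> None:
    """Print text, optionally colored, flushing the target file if given."""
    # illustrative stand-in for the real color helper
    colored = f"\u001b[32;1m{text}\u001b[0m" if color else text
    print(colored, end=end, file=file)
    if file:
        file.flush()  # the tests assert flush() fires only when a file is passed
```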
# -----------------------------
# DifyAgentCallbackHandler Tests
# -----------------------------
class TestDifyAgentCallbackHandler:
    def test_init_default_color(self):
        handler = module.DifyAgentCallbackHandler()
        assert handler.color == "green"
        assert handler.current_loop == 1

    def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_tool_start("tool1", {"a": 1})
        mock_print_text.assert_called()

    def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_tool_start("tool1", {"a": 1})
        mock_print_text.assert_not_called()

    def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        mock_trace_manager = MagicMock()
        handler.on_tool_end(
            tool_name="tool1",
            tool_inputs={"a": 1},
            tool_outputs="output",
            message_id="msg1",
            timer=123,
            trace_manager=mock_trace_manager,
        )
        assert mock_print_text.call_count >= 1
        mock_trace_manager.add_trace_task.assert_called_once()

    def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_tool_end(
            tool_name="tool1",
            tool_inputs={},
            tool_outputs="output",
        )
        assert mock_print_text.call_count >= 1

    def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_tool_error(Exception("error"))
        mock_print_text.assert_called_once()

    def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_tool_error(Exception("error"))
        mock_print_text.assert_not_called()

    @pytest.mark.parametrize("thought", ["thinking", ""])
    def test_on_agent_start(self, handler, enable_debug, mocker, thought):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_agent_start(thought)
        mock_print_text.assert_called()

    def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        current_loop = handler.current_loop
        handler.on_agent_finish()
        assert handler.current_loop == current_loop + 1
        mock_print_text.assert_called()

    def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker):
        mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
        handler.on_datasource_start("ds1", {"x": 1})
        mock_print_text.assert_called_once()

    def test_ignore_agent_property(self, disable_debug, handler):
        assert handler.ignore_agent is True

    def test_ignore_chat_model_property(self, disable_debug, handler):
        assert handler.ignore_chat_model is True

    def test_ignore_properties_when_debug_enabled(self, enable_debug, handler):
        assert handler.ignore_agent is False
        assert handler.ignore_chat_model is False
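Taken together, these tests specify that every logging hook is gated on `dify_config.DEBUG`, that the `ignore_*` properties invert that flag, and that the loop counter advances regardless. A hedged sketch of that gating pattern, with a stand-in config object rather than the real `dify_config`:

```python
class _Config:
    DEBUG = False  # stand-in for dify_config; assumption for this sketch


class AgentCallbackSketch:
    """Callback whose logging is a no-op unless debug mode is on."""

    def __init__(self, color: str = "green"):
        self.color = color
        self.current_loop = 1

    def on_agent_finish(self) -> None:
        if _Config.DEBUG:
            print(f"[on_agent_finish] loop {self.current_loop}")
        self.current_loop += 1  # the counter advances even with DEBUG off

    @property
    def ignore_agent(self) -> bool:
        # ignore_* properties report the opposite of the debug flag
        return not _Config.DEBUG
```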

View File

@@ -1,162 +0,0 @@
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import (
    DatasetIndexToolCallbackHandler,
)

@pytest.fixture
def mock_queue_manager(mocker):
    return mocker.Mock()

@pytest.fixture
def handler(mock_queue_manager, mocker):
    mocker.patch(
        "core.callback_handler.index_tool_callback_handler.db",
    )
    return DatasetIndexToolCallbackHandler(
        queue_manager=mock_queue_manager,
        app_id="app-1",
        message_id="msg-1",
        user_id="user-1",
        invoke_from=mocker.Mock(),
    )

class TestOnQuery:
    @pytest.mark.parametrize(
        ("invoke_from", "expected_role"),
        [
            (InvokeFrom.EXPLORE, "account"),
            (InvokeFrom.DEBUGGER, "account"),
            (InvokeFrom.WEB_APP, "end_user"),
        ],
    )
    def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role):
        # Arrange
        mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
        handler = DatasetIndexToolCallbackHandler(
            queue_manager=mock_queue_manager,
            app_id="app-1",
            message_id="msg-1",
            user_id="user-1",
            invoke_from=mocker.Mock(),
        )
        handler._invoke_from = invoke_from
        # Act
        handler.on_query("test query", "dataset-1")
        # Assert
        mock_db.session.add.assert_called_once()
        dataset_query = mock_db.session.add.call_args.args[0]
        assert dataset_query.created_by_role == expected_role
        mock_db.session.commit.assert_called_once()

    def test_on_query_none_values(self, mocker, mock_queue_manager):
        mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
        handler = DatasetIndexToolCallbackHandler(
            queue_manager=mock_queue_manager,
            app_id=None,
            message_id=None,
            user_id=None,
            invoke_from=None,
        )
        handler.on_query(None, None)
        mock_db.session.add.assert_called_once()
        mock_db.session.commit.assert_called_once()

class TestOnToolEnd:
    def test_on_tool_end_no_metadata(self, handler, mocker):
        mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
        document = mocker.Mock()
        document.metadata = None
        handler.on_tool_end([document])
        mock_db.session.commit.assert_not_called()

    def test_on_tool_end_dataset_document_not_found(self, handler, mocker):
        mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
        mock_db.session.scalar.return_value = None
        document = mocker.Mock()
        document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
        handler.on_tool_end([document])
        mock_db.session.scalar.assert_called_once()

    def test_on_tool_end_parent_child_index_with_child(self, handler, mocker):
        mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
        mock_dataset_doc = mocker.Mock()
        from core.callback_handler.index_tool_callback_handler import IndexStructureType

        mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX
        mock_dataset_doc.dataset_id = "dataset-1"
        mock_dataset_doc.id = "doc-1"
        mock_child_chunk = mocker.Mock()
        mock_child_chunk.segment_id = "segment-1"
        mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk]
        document = mocker.Mock()
        document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
        mock_query = mocker.Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        handler.on_tool_end([document])
        mock_query.update.assert_called_once()
        mock_db.session.commit.assert_called_once()

    def test_on_tool_end_non_parent_child_index(self, handler, mocker):
        mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
        mock_dataset_doc = mocker.Mock()
        mock_dataset_doc.doc_form = "OTHER"
        mock_db.session.scalar.return_value = mock_dataset_doc
        document = mocker.Mock()
        document.metadata = {
            "document_id": "doc-1",
            "doc_id": "node-1",
            "dataset_id": "dataset-1",
        }
        mock_query = mocker.Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        handler.on_tool_end([document])
        mock_query.update.assert_called_once()
        mock_db.session.commit.assert_called_once()

    def test_on_tool_end_empty_documents(self, handler):
        handler.on_tool_end([])

class TestReturnRetrieverResourceInfo:
    def test_publish_called(self, handler, mock_queue_manager, mocker):
        mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent")
        resources = [mocker.Mock()]
        handler.return_retriever_resource_info(resources)
        mock_queue_manager.publish.assert_called_once()
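The parametrized role test fixes a mapping from invocation source to the recorded `created_by_role`: console-side invocations (explore, debugger) are logged as `account`, web-app invocations as `end_user`. Sketched as a standalone function with a stand-in enum (member values assumed for the sketch):

```python
from enum import Enum


class InvokeFromSketch(Enum):
    """Stand-in for InvokeFrom; only the members exercised by the test."""
    EXPLORE = "explore"
    DEBUGGER = "debugger"
    WEB_APP = "web-app"


def created_by_role(invoke_from: InvokeFromSketch) -> str:
    """Map the invocation source to the role stored on the DatasetQuery row."""
    if invoke_from in (InvokeFromSketch.EXPLORE, InvokeFromSketch.DEBUGGER):
        return "account"
    return "end_user"
```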

View File

@@ -1,184 +0,0 @@
from unittest.mock import MagicMock, call

import pytest

from core.callback_handler.workflow_tool_callback_handler import (
    DifyWorkflowCallbackHandler,
)

class DummyToolInvokeMessage:
    """Lightweight dummy to simulate ToolInvokeMessage behavior."""

    def __init__(self, json_value: str):
        self._json_value = json_value

    def model_dump_json(self):
        return self._json_value

@pytest.fixture
def handler():
    """Fixture to create handler instance with deterministic color."""
    instance = DifyWorkflowCallbackHandler()
    instance.color = "blue"
    return instance

@pytest.fixture
def mock_print_text(mocker):
    """Mock print_text to avoid real stdout printing."""
    return mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text")

class TestDifyWorkflowCallbackHandler:
    def test_on_tool_execution_single_output_success(self, handler, mock_print_text):
        # Arrange
        tool_name = "test_tool"
        tool_inputs = {"a": 1}
        message = DummyToolInvokeMessage('{"key": "value"}')
        # Act
        results = list(
            handler.on_tool_execution(
                tool_name=tool_name,
                tool_inputs=tool_inputs,
                tool_outputs=[message],
            )
        )
        # Assert
        assert results == [message]
        assert mock_print_text.call_count == 4
        mock_print_text.assert_has_calls(
            [
                call("\n[on_tool_execution]\n", color="blue"),
                call("Tool: test_tool\n", color="blue"),
                call(
                    "Outputs: " + message.model_dump_json()[:1000] + "\n",
                    color="blue",
                ),
                call("\n"),
            ]
        )

    def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text):
        # Arrange
        tool_name = "multi_tool"
        outputs = [
            DummyToolInvokeMessage('{"id": 1}'),
            DummyToolInvokeMessage('{"id": 2}'),
        ]
        # Act
        results = list(
            handler.on_tool_execution(
                tool_name=tool_name,
                tool_inputs={},
                tool_outputs=outputs,
            )
        )
        # Assert
        assert results == outputs
        assert mock_print_text.call_count == 4 * len(outputs)

    def test_on_tool_execution_empty_iterable(self, handler, mock_print_text):
        # Arrange
        tool_name = "empty_tool"
        # Act
        results = list(
            handler.on_tool_execution(
                tool_name=tool_name,
                tool_inputs={},
                tool_outputs=[],
            )
        )
        # Assert
        assert results == []
        mock_print_text.assert_not_called()

    @pytest.mark.parametrize(
        ("invalid_outputs", "expected_exception"),
        [
            (None, TypeError),
            (123, TypeError),
            ("not_iterable", AttributeError),
        ],
    )
    def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception):
        # Arrange
        tool_name = "invalid_tool"
        # Act & Assert
        with pytest.raises(expected_exception):
            list(
                handler.on_tool_execution(
                    tool_name=tool_name,
                    tool_inputs={},
                    tool_outputs=invalid_outputs,
                )
            )

    def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text):
        # Arrange
        tool_name = "long_json_tool"
        long_json = "x" * 1500
        message = DummyToolInvokeMessage(long_json)
        # Act
        list(
            handler.on_tool_execution(
                tool_name=tool_name,
                tool_inputs={},
                tool_outputs=[message],
            )
        )
        # Assert
        expected_truncated = long_json[:1000]
        mock_print_text.assert_any_call(
            "Outputs: " + expected_truncated + "\n",
            color="blue",
        )

    def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text):
        # Arrange
        tool_name = "exception_tool"
        bad_message = MagicMock()
        bad_message.model_dump_json.side_effect = ValueError("JSON error")
        # Act & Assert
        with pytest.raises(ValueError):
            list(
                handler.on_tool_execution(
                    tool_name=tool_name,
                    tool_inputs={},
                    tool_outputs=[bad_message],
                )
            )
        # Ensure first two prints happened before failure
        assert mock_print_text.call_count >= 2

    def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text):
        # Arrange
        tool_name = "optional_params_tool"
        message = DummyToolInvokeMessage('{"data": "ok"}')
        # Act
        results = list(
            handler.on_tool_execution(
                tool_name=tool_name,
                tool_inputs={},
                tool_outputs=[message],
                message_id=None,
                timer=None,
                trace_manager=None,
            )
        )
        assert results == [message]
        assert mock_print_text.call_count == 4
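These tests collectively specify `on_tool_execution` as a generator: it re-yields each output, logs four lines per message, and truncates the JSON preview at 1,000 characters. A hedged sketch of that contract, with the log text taken from the assertions above and the name suffixed to mark it as illustrative:

```python
from collections.abc import Iterable, Iterator


def on_tool_execution_sketch(tool_name: str, tool_outputs: Iterable) -> Iterator:
    """Re-yield tool outputs while logging a truncated preview of each one."""
    for message in tool_outputs:
        print("\n[on_tool_execution]\n", end="")
        print(f"Tool: {tool_name}\n", end="")
        # serialization happens per message, so a failing model_dump_json()
        # aborts mid-stream, as the ValueError test expects
        print("Outputs: " + message.model_dump_json()[:1000] + "\n", end="")
        print("\n", end="")
        yield message
```

Because it is a generator, nothing runs until the caller iterates — which is why every test wraps the call in `list(...)`, and why `tool_outputs=None` only fails (with `TypeError`) once iteration starts.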

View File

@@ -1,9 +0,0 @@
from core.entities.agent_entities import PlanningStrategy

def test_planning_strategy_values_are_stable() -> None:
    # Arrange / Act / Assert
    assert PlanningStrategy.ROUTER.value == "router"
    assert PlanningStrategy.REACT_ROUTER.value == "react_router"
    assert PlanningStrategy.REACT.value == "react"
    assert PlanningStrategy.FUNCTION_CALL.value == "function_call"
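Pinning `.value` strings like this guards anything serialized to storage or APIs: renaming a member is safe, but changing its value silently breaks persisted configs. A mirror of the asserted values as a sketch (not the source enum):

```python
from enum import Enum


class PlanningStrategySketch(str, Enum):
    """Mirror of the values the test pins; str mixin keeps members JSON-friendly."""
    ROUTER = "router"
    REACT_ROUTER = "react_router"
    REACT = "react"
    FUNCTION_CALL = "function_call"
```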

View File

@@ -1,18 +0,0 @@
from core.entities.document_task import DocumentTask

def test_document_task_keeps_indexing_identifiers() -> None:
    # Arrange
    document_ids = ("doc-1", "doc-2")
    # Act
    task = DocumentTask(
        tenant_id="tenant-1",
        dataset_id="dataset-1",
        document_ids=document_ids,
    )
    # Assert
    assert task.tenant_id == "tenant-1"
    assert task.dataset_id == "dataset-1"
    assert task.document_ids == document_ids
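The test implies `DocumentTask` is a simple value object: three identifiers, with `document_ids` stored as given. A dataclass sketch of that shape (the real class may use a different base):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DocumentTaskSketch:
    """Value object carrying the identifiers needed to index documents."""
    tenant_id: str
    dataset_id: str
    document_ids: tuple[str, ...]  # kept exactly as passed in
```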

View File

@@ -1,7 +0,0 @@
from core.entities.embedding_type import EmbeddingInputType

def test_embedding_input_type_values_are_stable() -> None:
    # Arrange / Act / Assert
    assert EmbeddingInputType.DOCUMENT.value == "document"
    assert EmbeddingInputType.QUERY.value == "query"

View File

@@ -1,45 +0,0 @@
from core.entities.execution_extra_content import (
    ExecutionExtraContentDomainModel,
    HumanInputContent,
    HumanInputFormDefinition,
    HumanInputFormSubmissionData,
)
from dify_graph.nodes.human_input.entities import FormInput, UserAction
from dify_graph.nodes.human_input.enums import FormInputType
from models.execution_extra_content import ExecutionContentType

def test_human_input_content_defaults_and_domain_alias() -> None:
    # Arrange
    form_definition = HumanInputFormDefinition(
        form_id="form-1",
        node_id="node-1",
        node_title="Human Input",
        form_content="Please confirm",
        inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")],
        actions=[UserAction(id="confirm", title="Confirm")],
        resolved_default_values={"answer": "yes"},
        expiration_time=1_700_000_000,
    )
    submission_data = HumanInputFormSubmissionData(
        node_id="node-1",
        node_title="Human Input",
        rendered_content="Please confirm",
        action_id="confirm",
        action_text="Confirm",
    )
    # Act
    content = HumanInputContent(
        workflow_run_id="workflow-run-1",
        submitted=True,
        form_definition=form_definition,
        form_submission_data=submission_data,
    )
    # Assert
    assert form_definition.model_config.get("frozen") is True
    assert content.type == ExecutionContentType.HUMAN_INPUT
    assert content.form_definition is form_definition
    assert content.form_submission_data is submission_data
    assert ExecutionExtraContentDomainModel is HumanInputContent

View File

@@ -1,45 +0,0 @@
from core.entities.knowledge_entities import (
    PipelineDataset,
    PipelineDocument,
    PipelineGenerateResponse,
)

def test_pipeline_dataset_normalizes_none_description() -> None:
    # Arrange / Act
    dataset = PipelineDataset(
        id="dataset-1",
        name="Dataset",
        description=None,
        chunk_structure="parent-child",
    )
    # Assert
    assert dataset.description == ""

def test_pipeline_generate_response_builds_nested_models() -> None:
    # Arrange
    dataset = PipelineDataset(
        id="dataset-1",
        name="Dataset",
        description="Knowledge base",
        chunk_structure="parent-child",
    )
    document = PipelineDocument(
        id="doc-1",
        position=1,
        data_source_type="file",
        data_source_info={"name": "spec.pdf"},
        name="spec.pdf",
        indexing_status="completed",
        enabled=True,
    )
    # Act
    response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document])
    # Assert
    assert response.batch == "batch-1"
    assert response.dataset.id == "dataset-1"
    assert response.documents[0].id == "doc-1"

View File

@@ -1,450 +0,0 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch

import pytest

from core.entities import mcp_provider as mcp_provider_module
from core.entities.mcp_provider import (
    DEFAULT_EXPIRES_IN,
    DEFAULT_TOKEN_TYPE,
    MCPProviderEntity,
)
from core.mcp.types import OAuthTokens

def _build_mcp_provider_entity() -> MCPProviderEntity:
    now = datetime(2025, 1, 1, tzinfo=UTC)
    return MCPProviderEntity(
        id="provider-1",
        provider_id="server-1",
        name="Example MCP",
        tenant_id="tenant-1",
        user_id="user-1",
        server_url="encrypted-server-url",
        headers={},
        timeout=30,
        sse_read_timeout=300,
        authed=False,
        credentials={},
        tools=[],
        icon={"en_US": "icon.png"},
        created_at=now,
        updated_at=now,
    )

def test_from_db_model_maps_fields() -> None:
    # Arrange
    now = datetime(2025, 1, 1, tzinfo=UTC)
    db_provider = SimpleNamespace(
        id="provider-1",
        server_identifier="server-1",
        name="Example MCP",
        tenant_id="tenant-1",
        user_id="user-1",
        server_url="encrypted-server-url",
        headers={"Authorization": "enc"},
        timeout=15,
        sse_read_timeout=120,
        authed=True,
        credentials={"access_token": "enc-token"},
        tool_dict=[{"name": "search"}],
        icon=None,
        created_at=now,
        updated_at=now,
    )
    # Act
    entity = MCPProviderEntity.from_db_model(db_provider)
    # Assert
    assert entity.provider_id == "server-1"
    assert entity.tools == [{"name": "search"}]
    assert entity.icon == ""

def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com")
    # Act
    redirect_url = entity.redirect_url
    # Assert
    assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback"

def test_client_metadata_for_authorization_code_flow() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
        # Act
        metadata = entity.client_metadata
    # Assert
    assert metadata.grant_types == ["refresh_token", "authorization_code"]
    assert metadata.redirect_uris == [entity.redirect_url]
    assert metadata.response_types == ["code"]

def test_client_metadata_for_client_credentials_flow() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    credentials = {"client_information": {"grant_types": ["client_credentials"]}}
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
        # Act
        metadata = entity.client_metadata
    # Assert
    assert metadata.grant_types == ["refresh_token", "client_credentials"]
    assert metadata.redirect_uris == []
    assert metadata.response_types == []

def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    credentials = {
        "grant_type": "client_credentials",
        "client_information": {"grant_types": ["authorization_code"]},
    }
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
        # Act
        metadata = entity.client_metadata
    # Assert
    assert metadata.grant_types == ["refresh_token", "authorization_code"]
    assert metadata.redirect_uris == [entity.redirect_url]
    assert metadata.response_types == ["code"]

def test_provider_icon_returns_icon_dict_as_is() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
    # Act
    icon = entity.provider_icon
    # Assert
    assert icon == {"en_US": "icon.png"}

def test_provider_icon_uses_signed_url_for_plain_path() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"})
    with patch(
        "core.entities.mcp_provider.file_helpers.get_signed_file_url",
        return_value="https://signed.example.com/icons/mcp.png",
    ) as mock_get_signed_url:
        # Act
        icon = entity.provider_icon
    # Assert
    mock_get_signed_url.assert_called_once_with("icons/mcp.png")
    assert icon == "https://signed.example.com/icons/mcp.png"

def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
    with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
        # Act
        response = entity.to_api_response(include_sensitive=False)
    # Assert
    assert response["author"] == "Anonymous"
    assert response["masked_headers"] == {}
    assert response["is_dynamic_registration"] is True
    assert "authentication" not in response

def test_to_api_response_with_sensitive_data_includes_masked_values() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(
        update={
            "credentials": {"client_information": {"is_dynamic_registration": False}},
            "icon": {"en_US": "icon.png"},
        }
    )
    with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
        with patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}):
            with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}):
                # Act
                response = entity.to_api_response(user_name="Rajat", include_sensitive=True)
    # Assert
    assert response["author"] == "Rajat"
    assert response["masked_headers"] == {"Authorization": "Be****"}
    assert response["authentication"] == {"client_id": "cl****"}
    assert response["is_dynamic_registration"] is False

def test_retrieve_client_information_decrypts_nested_secret() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}}
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
        with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt:
            # Act
            client_info = entity.retrieve_client_information()
    # Assert
    assert client_info is not None
    assert client_info.client_id == "client-1"
    assert client_info.client_secret == "plain-secret"
    mock_decrypt.assert_called_once_with("tenant-1", "enc-secret")

def test_retrieve_client_information_returns_none_for_missing_data() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
        # Act
        result_empty = entity.retrieve_client_information()
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
        # Act
        result_invalid = entity.retrieve_client_information()
    # Assert
    assert result_empty is None
    assert result_invalid is None

def test_masked_server_url_hides_path_segments() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch.object(
        MCPProviderEntity,
        "decrypt_server_url",
        return_value="https://api.example.com/v1/mcp?query=1",
    ):
        # Act
        masked_url = entity.masked_server_url()
    # Assert
    assert masked_url == "https://api.example.com/******?query=1"

def test_mask_value_covers_short_and_long_values() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    # Act
    short_masked = entity._mask_value("short")
    long_masked = entity._mask_value("abcdefghijkl")
    # Assert
    assert short_masked == "*****"
    assert long_masked == "ab********kl"

def test_masked_headers_masks_all_decrypted_header_values() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}):
        # Act
        masked = entity.masked_headers()
    # Assert
    assert masked == {"Authorization": "ab****gh"}

def test_masked_credentials_handles_nested_secret_fields() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    credentials = {
        "client_information": {
            "client_id": "client-id",
            "encrypted_client_secret": "encrypted-value",
            "client_secret": "plain-secret",
        }
    }
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
        with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"):
            # Act
            masked = entity.masked_credentials()
    # Assert
    assert masked["client_id"] == "cl*****id"
    assert masked["client_secret"] == "pl********et"

def test_masked_credentials_returns_empty_for_missing_client_information() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
        # Act
        masked_empty = entity.masked_credentials()
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
        # Act
        masked_invalid = entity.masked_credentials()
    # Assert
    assert masked_empty == {}
    assert masked_invalid == {}

def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
    with patch.object(
        MCPProviderEntity,
        "decrypt_credentials",
        return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"},
    ):
        # Act
        tokens = entity.retrieve_tokens()
    # Assert
    assert isinstance(tokens, OAuthTokens)
    assert tokens.access_token == "token"
    assert tokens.token_type == DEFAULT_TOKEN_TYPE
    assert tokens.expires_in == DEFAULT_EXPIRES_IN
    assert tokens.refresh_token == "refresh"

def test_retrieve_tokens_returns_none_when_access_token_missing() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
    with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt:
        # Act
        tokens = entity.retrieve_tokens()
    # Assert
    mock_decrypt.assert_called_once()
    assert tokens is None

def test_decrypt_server_url_delegates_to_encrypter() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock:
        # Act
        decrypted = entity.decrypt_server_url()
    # Assert
    mock.assert_called_once_with("tenant-1", "encrypted-server-url")
    assert decrypted == "https://api.example.com"

def test_decrypt_authentication_injects_authorization_for_oauth() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}})
    with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}):
        with patch.object(
            MCPProviderEntity,
            "retrieve_tokens",
            return_value=OAuthTokens(access_token="abc123", token_type="bearer"),
        ):
            # Act
            headers = entity.decrypt_authentication()
    # Assert
    assert headers["Authorization"] == "Bearer abc123"

def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None:
    # Arrange
    entity = _build_mcp_provider_entity().model_copy(
        update={"authed": True, "headers": {"Authorization": "encrypted-header"}}
    )
    with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}):
        with patch.object(
            MCPProviderEntity,
            "retrieve_tokens",
            return_value=OAuthTokens(access_token="abc", token_type="bearer"),
        ) as mock_tokens:
            # Act
            headers = entity.decrypt_authentication()
    # Assert
    mock_tokens.assert_not_called()
    assert headers == {"Authorization": "existing"}

def test_decrypt_dict_returns_empty_for_empty_input() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    # Act
    decrypted = entity._decrypt_dict({})
    # Assert
    assert decrypted == {}

def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""}
    # Act
    result = entity._decrypt_dict(input_data)
    # Assert
    assert result is input_data

def test_decrypt_dict_only_decrypts_top_level_string_values() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    decryptor = Mock()
    decryptor.decrypt.return_value = {"api_key": "plain-key"}

    def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache):
        assert tenant_id == "tenant-1"
        assert any(item.name == "api_key" for item in config)
        return decryptor, None

    with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter):
        # Act
        result = entity._decrypt_dict(
            {
                "api_key": "encrypted-key",
                "nested": {"client_id": "unchanged"},
                "empty": "",
                "count": 2,
            }
        )
    # Assert
    decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"})
    assert result["api_key"] == "plain-key"
    assert result["nested"] == {"client_id": "unchanged"}
    assert result["count"] == 2

def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None:
    # Arrange
    entity = _build_mcp_provider_entity()
    with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock:
        # Act
        headers = entity.decrypt_headers()
        credentials = entity.decrypt_credentials()
    # Assert
    assert mock.call_count == 2
    assert headers == {"h": "v"}
    assert credentials == {"c": "v"}

View File

@@ -1,92 +0,0 @@
"""Unit tests for model entity behavior and invariants.
Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus,
ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n
labels are provided via I18nObject, model metadata aligns with FetchFrom and
ModelType expectations, and ProviderEntity/ConfigurateMethod interactions
drive provider mapping behavior.
"""
import pytest
from core.entities.model_entities import (
DefaultModelEntity,
DefaultModelProviderEntity,
ModelStatus,
ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
)
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity:
return ProviderModelWithStatusEntity(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=status,
)
def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Arrange
provider_entity = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
# Act
simple_provider = SimpleModelProviderEntity(provider_entity)
# Assert
assert simple_provider.provider == "openai"
assert simple_provider.label.en_US == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]
def test_provider_model_with_status_raises_for_known_error_statuses() -> None:
# Arrange
expectations = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}
for status, message in expectations.items():
# Act / Assert
with pytest.raises(ValueError, match=message):
_build_model_with_status(status).raise_for_status()
def test_provider_model_with_status_allows_active_and_credential_removed() -> None:
# Arrange
active_model = _build_model_with_status(ModelStatus.ACTIVE)
removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED)
# Act / Assert
active_model.raise_for_status()
removed_model.raise_for_status()
def test_default_model_entity_accepts_model_field_name() -> None:
# Arrange / Act
default_model = DefaultModelEntity(
model="gpt-4o-mini",
model_type=ModelType.LLM,
provider=DefaultModelProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
),
)
# Assert
assert default_model.model == "gpt-4o-mini"
assert default_model.provider.provider == "openai"

View File

@@ -1,22 +0,0 @@
from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
def test_common_parameter_type_values_are_stable() -> None:
# Arrange / Act / Assert
assert CommonParameterType.SECRET_INPUT.value == "secret-input"
assert CommonParameterType.MODEL_SELECTOR.value == "model-selector"
assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select"
assert CommonParameterType.ARRAY.value == "array"
assert CommonParameterType.OBJECT.value == "object"
def test_selector_scope_values_are_stable() -> None:
# Arrange / Act / Assert
assert AppSelectorScope.WORKFLOW.value == "workflow"
assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding"
assert ToolSelectorScope.BUILTIN.value == "builtin"

View File

@@ -1,72 +0,0 @@
import pytest
from core.entities.parameter_entities import AppSelectorScope
from core.entities.provider_entities import (
BasicProviderConfig,
ModelSettings,
ProviderConfig,
ProviderQuotaType,
)
from core.tools.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
def test_provider_quota_type_value_of_returns_enum_member() -> None:
# Arrange / Act
quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value)
# Assert
assert quota_type == ProviderQuotaType.TRIAL
def test_provider_quota_type_value_of_rejects_unknown_values() -> None:
# Arrange / Act / Assert
with pytest.raises(ValueError, match="No matching enum found"):
ProviderQuotaType.value_of("enterprise")
def test_basic_provider_config_type_value_of_handles_known_values() -> None:
# Arrange / Act
parameter_type = BasicProviderConfig.Type.value_of("text-input")
# Assert
assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT
def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None:
# Arrange / Act / Assert
with pytest.raises(ValueError, match="invalid mode value"):
BasicProviderConfig.Type.value_of("unknown")
def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None:
# Arrange
provider_config = ProviderConfig(
type=BasicProviderConfig.Type.SELECT,
name="workspace",
scope=AppSelectorScope.ALL,
options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))],
)
# Act
basic_config = provider_config.to_basic_provider_config()
# Assert
assert isinstance(basic_config, BasicProviderConfig)
assert basic_config.type == BasicProviderConfig.Type.SELECT
assert basic_config.name == "workspace"
def test_model_settings_accepts_model_field_name() -> None:
# Arrange / Act
settings = ModelSettings(
model="gpt-4o",
model_type=ModelType.LLM,
enabled=True,
load_balancing_enabled=False,
load_balancing_configs=[],
)
# Assert
assert settings.model == "gpt-4o"
assert settings.model_type == ModelType.LLM

View File

@@ -1,137 +0,0 @@
import httpx
import pytest
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from models.api_based_extension import APIBasedExtensionPoint
def test_request_success(mocker):
# Mock httpx.Client and its context manager
mock_client = mocker.MagicMock()
mock_client_instance = mock_client.__enter__.return_value
mocker.patch("httpx.Client", return_value=mock_client)
mock_response = mocker.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"result": "success"}
mock_client_instance.request.return_value = mock_response
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
result = requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"})
assert result == {"result": "success"}
mock_client_instance.request.assert_called_once_with(
method="POST",
url="http://example.com",
json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": "bar"}},
headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"},
)
def test_request_with_ssrf_proxy(mocker):
# Mock dify_config
mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081")
# Mock httpx.Client
mock_client = mocker.MagicMock()
mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
mock_client_instance = mock_client.__enter__.return_value
# Mock response
mock_response = mocker.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"result": "success"}
mock_client_instance.request.return_value = mock_response
# Mock HTTPTransport
mock_transport = mocker.patch("httpx.HTTPTransport")
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
requestor.request(APIBasedExtensionPoint.PING, {})
# Verify httpx.Client was called with mounts
mock_client_class.assert_called_once()
kwargs = mock_client_class.call_args.kwargs
assert "mounts" in kwargs
assert "http://" in kwargs["mounts"]
assert "https://" in kwargs["mounts"]
assert mock_transport.call_count == 2
def test_request_with_only_one_proxy_config(mocker):
# Mock dify_config with only one proxy
mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None)
# Mock httpx.Client
mock_client = mocker.MagicMock()
mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
mock_client_instance = mock_client.__enter__.return_value
# Mock response
mock_response = mocker.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"result": "success"}
mock_client_instance.request.return_value = mock_response
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
requestor.request(APIBasedExtensionPoint.PING, {})
# Verify httpx.Client was called with mounts=None (default)
mock_client_class.assert_called_once()
kwargs = mock_client_class.call_args.kwargs
assert kwargs.get("mounts") is None
def test_request_timeout(mocker):
mock_client = mocker.MagicMock()
mock_client_instance = mock_client.__enter__.return_value
mocker.patch("httpx.Client", return_value=mock_client)
mock_client_instance.request.side_effect = httpx.TimeoutException("timeout")
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
with pytest.raises(ValueError, match="request timeout"):
requestor.request(APIBasedExtensionPoint.PING, {})
def test_request_connection_error(mocker):
mock_client = mocker.MagicMock()
mock_client_instance = mock_client.__enter__.return_value
mocker.patch("httpx.Client", return_value=mock_client)
mock_client_instance.request.side_effect = httpx.RequestError("error")
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
with pytest.raises(ValueError, match="request connection error"):
requestor.request(APIBasedExtensionPoint.PING, {})
def test_request_error_status_code(mocker):
mock_client = mocker.MagicMock()
mock_client_instance = mock_client.__enter__.return_value
mocker.patch("httpx.Client", return_value=mock_client)
mock_response = mocker.MagicMock()
mock_response.status_code = 404
mock_response.text = "Not Found"
mock_client_instance.request.return_value = mock_response
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"):
requestor.request(APIBasedExtensionPoint.PING, {})
def test_request_error_status_code_long_content(mocker):
mock_client = mocker.MagicMock()
mock_client_instance = mock_client.__enter__.return_value
mocker.patch("httpx.Client", return_value=mock_client)
mock_response = mocker.MagicMock()
mock_response.status_code = 500
mock_response.text = "A" * 200 # Testing truncation of content
mock_client_instance.request.return_value = mock_response
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
expected_content = "A" * 100
with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"):
requestor.request(APIBasedExtensionPoint.PING, {})

View File

@@ -1,281 +0,0 @@
import json
import types
from unittest.mock import MagicMock, mock_open, patch
import pytest
from core.extension.extensible import Extensible
class TestExtensible:
def test_init(self):
tenant_id = "tenant_123"
config = {"key": "value"}
ext = Extensible(tenant_id, config)
assert ext.tenant_id == tenant_id
assert ext.config == config
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
@patch("core.extension.extensible.os.path.exists")
@patch("core.extension.extensible.Path.read_text")
@patch("core.extension.extensible.importlib.util.module_from_spec")
@patch("core.extension.extensible.sort_to_dict_by_position_map")
def test_scan_extensions_success(
self,
mock_sort,
mock_module_from_spec,
mock_read_text,
mock_exists,
mock_isdir,
mock_listdir,
mock_dirname,
mock_find_spec,
):
# Setup
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
module_spec = MagicMock()
module_spec.loader = MagicMock()
mock_find_spec.side_effect = [package_spec, module_spec]
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [
["ext1"], # package_dir
["ext1.py", "__builtin__"], # subdir_path
]
mock_isdir.return_value = True
mock_exists.return_value = True
mock_read_text.return_value = "10"
# Use types.ModuleType to avoid MagicMock __dict__ issues
mock_mod = types.ModuleType("ext1")
class MockExtension(Extensible):
pass
mock_mod.MockExtension = MockExtension
mock_module_from_spec.return_value = mock_mod
mock_sort.side_effect = lambda position_map, data, name_func: data
# Execute
results = Extensible.scan_extensions()
# Assert
assert len(results) == 1
assert results[0].name == "ext1"
assert results[0].position == 10
assert results[0].builtin is True
assert results[0].extension_class == MockExtension
@patch("core.extension.extensible.importlib.util.find_spec")
def test_scan_extensions_package_not_found(self, mock_find_spec):
mock_find_spec.return_value = None
with pytest.raises(ImportError, match="Could not find package"):
Extensible.scan_extensions()
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
mock_find_spec.return_value = package_spec
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []]
mock_isdir.side_effect = [False, True]
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
results = Extensible.scan_extensions()
assert len(results) == 0
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
@patch("core.extension.extensible.os.path.exists")
@patch("core.extension.extensible.importlib.util.module_from_spec")
def test_scan_extensions_not_builtin_success(
self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
):
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
module_spec = MagicMock()
module_spec.loader = MagicMock()
mock_find_spec.side_effect = [package_spec, module_spec]
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]]
mock_isdir.return_value = True
# exists checks: only schema.json needs to exist
mock_exists.return_value = True
mock_mod = types.ModuleType("ext1")
class MockExtension(Extensible):
pass
mock_mod.MockExtension = MockExtension
mock_module_from_spec.return_value = mock_mod
schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]})
with (
patch("builtins.open", mock_open(read_data=schema_content)),
patch(
"core.extension.extensible.sort_to_dict_by_position_map",
side_effect=lambda position_map, data, name_func: data,
),
):
results = Extensible.scan_extensions()
assert len(results) == 1
assert results[0].name == "ext1"
assert results[0].builtin is False
assert results[0].label == {"en": "Test"}
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
@patch("core.extension.extensible.os.path.exists")
@patch("core.extension.extensible.importlib.util.module_from_spec")
def test_scan_extensions_not_builtin_missing_schema(
self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
):
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
module_spec = MagicMock()
module_spec.loader = MagicMock()
mock_find_spec.side_effect = [package_spec, module_spec]
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
mock_isdir.return_value = True
# exists: only schema.json is checked, and it returns False
mock_exists.return_value = False
mock_mod = types.ModuleType("ext1")
class MockExtension(Extensible):
pass
mock_mod.MockExtension = MockExtension
mock_module_from_spec.return_value = mock_mod
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
results = Extensible.scan_extensions()
assert len(results) == 0
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
@patch("core.extension.extensible.importlib.util.module_from_spec")
@patch("core.extension.extensible.os.path.exists")
def test_scan_extensions_no_extension_class(
self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
):
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
module_spec = MagicMock()
module_spec.loader = MagicMock()
mock_find_spec.side_effect = [package_spec, module_spec]
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
mock_isdir.return_value = True
# Mock not builtin
mock_exists.return_value = False
mock_mod = types.ModuleType("ext1")
mock_mod.SomeOtherClass = type("SomeOtherClass", (), {})
mock_module_from_spec.return_value = mock_mod
# The schema check is never reached because no Extensible subclass is found in the module
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
results = Extensible.scan_extensions()
assert len(results) == 0
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
mock_find_spec.side_effect = [package_spec, None] # No module spec
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
mock_isdir.return_value = True
with pytest.raises(ImportError, match="Failed to load module"):
Extensible.scan_extensions()
@patch("core.extension.extensible.importlib.util.find_spec")
def test_scan_extensions_general_exception(self, mock_find_spec):
mock_find_spec.side_effect = Exception("Unexpected error")
with pytest.raises(Exception, match="Unexpected error"):
Extensible.scan_extensions()
@patch("core.extension.extensible.importlib.util.find_spec")
@patch("core.extension.extensible.os.path.dirname")
@patch("core.extension.extensible.os.listdir")
@patch("core.extension.extensible.os.path.isdir")
@patch("core.extension.extensible.os.path.exists")
@patch("core.extension.extensible.Path.read_text")
@patch("core.extension.extensible.importlib.util.module_from_spec")
def test_scan_extensions_builtin_without_position_file(
self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
):
package_spec = MagicMock()
package_spec.origin = "/path/to/pkg/__init__.py"
module_spec = MagicMock()
module_spec.loader = MagicMock()
mock_find_spec.side_effect = [package_spec, module_spec]
mock_dirname.return_value = "/path/to/pkg"
mock_listdir.side_effect = [["ext1"], ["ext1.py", "__builtin__"]]
mock_isdir.return_value = True
# builtin exists in listdir, but os.path.exists(builtin_file_path) returns False
mock_exists.return_value = False
mock_mod = types.ModuleType("ext1")
class MockExtension(Extensible):
pass
mock_mod.MockExtension = MockExtension
mock_module_from_spec.return_value = mock_mod
with patch(
"core.extension.extensible.sort_to_dict_by_position_map",
side_effect=lambda position_map, data, name_func: data,
):
results = Extensible.scan_extensions()
assert len(results) == 1
assert results[0].position == 0

View File

@@ -1,90 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.extension.extensible import ExtensionModule, ModuleExtension
from core.extension.extension import Extension
class TestExtension:
def setup_method(self):
# Reset the private class attribute before each test
Extension._Extension__module_extensions = {}
def test_init(self):
# Mock scan_extensions for Moderation and ExternalDataTool
mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")}
mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")}
extension = Extension()
# We need to mock scan_extensions on the classes defined in Extension.module_classes
with (
patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions),
patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions),
):
extension.init()
# Check if internal state is updated
internal_state = Extension._Extension__module_extensions
assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions
assert internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions
def test_module_extensions_success(self):
# Setup data
mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")}
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions}
extension = Extension()
result = extension.module_extensions(ExtensionModule.MODERATION.value)
assert len(result) == 2
assert any(e.name == "name1" for e in result)
assert any(e.name == "name2" for e in result)
def test_module_extensions_not_found(self):
extension = Extension()
with pytest.raises(ValueError, match="Extension Module unknown not found"):
extension.module_extensions("unknown")
def test_module_extension_success(self):
mock_ext = ModuleExtension(name="test_ext")
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
extension = Extension()
result = extension.module_extension(ExtensionModule.MODERATION, "test_ext")
assert result == mock_ext
def test_module_extension_module_not_found(self):
extension = Extension()
# ExtensionModule.MODERATION is "moderation"
with pytest.raises(ValueError, match="Extension Module moderation not found"):
extension.module_extension(ExtensionModule.MODERATION, "any")
def test_module_extension_extension_not_found(self):
# We need a non-empty dict because 'if not module_extensions' in extension.py
# returns True for an empty dict, which raises the module not found error instead.
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}}
extension = Extension()
with pytest.raises(ValueError, match="Extension unknown not found"):
extension.module_extension(ExtensionModule.MODERATION, "unknown")
def test_extension_class_success(self):
class MockClass:
pass
mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass)
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
extension = Extension()
result = extension.extension_class(ExtensionModule.MODERATION, "test_ext")
assert result == MockClass
def test_extension_class_none(self):
mock_ext = ModuleExtension(name="test_ext", extension_class=None)
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
extension = Extension()
with pytest.raises(AssertionError):
extension.extension_class(ExtensionModule.MODERATION, "test_ext")

View File

@@ -1,145 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.external_data_tool.api.api import ApiExternalDataTool
from models.api_based_extension import APIBasedExtensionPoint
def test_api_external_data_tool_name():
assert ApiExternalDataTool.name == "api"
@patch("core.external_data_tool.api.api.db")
def test_validate_config_success(mock_db):
mock_extension = MagicMock()
mock_extension.id = "ext_id"
mock_extension.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = mock_extension
# Should not raise exception
ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
def test_validate_config_missing_id():
with pytest.raises(ValueError, match="api_based_extension_id is required"):
ApiExternalDataTool.validate_config("tenant_id", {})
@patch("core.external_data_tool.api.api.db")
def test_validate_config_invalid_id(mock_db):
mock_db.session.scalar.return_value = None
with pytest.raises(ValueError, match="api_based_extension_id is invalid"):
ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
@pytest.fixture
def api_tool():
# Use keyword arguments, since ApiExternalDataTool inherits from ExternalDataTool, a Pydantic BaseModel
return ApiExternalDataTool(
tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"}
)
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool):
mock_extension = MagicMock()
mock_extension.id = "ext_id"
mock_extension.tenant_id = "tenant_id"
mock_extension.api_endpoint = "http://api"
mock_extension.api_key = "encrypted_key"
mock_db.session.scalar.return_value = mock_extension
mock_encrypter.decrypt_token.return_value = "decrypted_key"
mock_requestor = mock_requestor_class.return_value
mock_requestor.request.return_value = {"result": "success_result"}
res = api_tool.query({"input1": "value1"}, "query_str")
assert res == "success_result"
mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key")
mock_requestor.request.assert_called_once_with(
point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"},
)
def test_query_missing_config():
api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1")
api_tool.config = None # Force None
with pytest.raises(ValueError, match="config is required"):
api_tool.query({}, "")
def test_query_missing_extension_id():
api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config={"dummy": "value"})
with pytest.raises(AssertionError, match="api_based_extension_id is required"):
api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
def test_query_invalid_extension(mock_db, api_tool):
mock_db.session.scalar.return_value = None
with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"):
api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool):
mock_extension = MagicMock()
mock_extension.id = "ext_id"
mock_extension.tenant_id = "tenant_id"
mock_extension.api_endpoint = "http://api"
mock_extension.api_key = "encrypted_key"
mock_db.session.scalar.return_value = mock_extension
mock_encrypter.decrypt_token.return_value = "decrypted_key"
mock_requestor_class.side_effect = Exception("init error")
with pytest.raises(ValueError, match=".*error: init error"):
api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool):
mock_extension = MagicMock()
mock_extension.id = "ext_id"
mock_extension.tenant_id = "tenant_id"
mock_extension.api_endpoint = "http://api"
mock_extension.api_key = "encrypted_key"
mock_db.session.scalar.return_value = mock_extension
mock_encrypter.decrypt_token.return_value = "decrypted_key"
mock_requestor = mock_requestor_class.return_value
mock_requestor.request.return_value = {"other": "value"}
with pytest.raises(ValueError, match=".*error: result not found in response"):
api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool):
mock_extension = MagicMock()
mock_extension.id = "ext_id"
mock_extension.tenant_id = "tenant_id"
mock_extension.api_endpoint = "http://api"
mock_extension.api_key = "encrypted_key"
mock_db.session.scalar.return_value = mock_extension
mock_encrypter.decrypt_token.return_value = "decrypted_key"
mock_requestor = mock_requestor_class.return_value
mock_requestor.request.return_value = {"result": 123} # Not a string
with pytest.raises(ValueError, match=".*error: result is not string"):
api_tool.query({}, "")

View File

@@ -1,66 +0,0 @@
import pytest
from core.extension.extensible import ExtensionModule
from core.external_data_tool.base import ExternalDataTool
class TestExternalDataTool:
def test_module_attribute(self):
assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL
def test_init(self):
# Create a concrete subclass to test init
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
return super().validate_config(tenant_id, config)
def query(self, inputs: dict, query: str | None = None) -> str:
return super().query(inputs, query)
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"})
assert tool.tenant_id == "tenant_1"
assert tool.app_id == "app_1"
assert tool.variable == "var_1"
assert tool.config == {"key": "value"}
def test_init_without_config(self):
# Create a concrete subclass to test init
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
pass
def query(self, inputs: dict, query: str | None = None) -> str:
return ""
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
assert tool.tenant_id == "tenant_1"
assert tool.app_id == "app_1"
assert tool.variable == "var_1"
assert tool.config is None
def test_validate_config_raises_not_implemented(self):
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
return super().validate_config(tenant_id, config)
def query(self, inputs: dict, query: str | None = None) -> str:
return ""
with pytest.raises(NotImplementedError):
ConcreteTool.validate_config("tenant_1", {})
def test_query_raises_not_implemented(self):
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
pass
def query(self, inputs: dict, query: str | None = None) -> str:
return super().query(inputs, query)
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
with pytest.raises(NotImplementedError):
tool.query({})

View File

@@ -1,115 +0,0 @@
from unittest.mock import patch

import pytest
from flask import Flask

from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.external_data_fetch import ExternalDataFetch


class TestExternalDataFetch:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        return app

    def test_fetch_success(self, app):
        with app.app_context():
            fetcher = ExternalDataFetch()

            # Set up two tools with distinct configs
            tool1 = ExternalDataVariableEntity(variable="var1", type="type1", config={"c1": "v1"})
            tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"})
            external_data_tools = [tool1, tool2]
            inputs = {"input_key": "input_value"}
            query = "test query"

            with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
                # Create a distinct mock instance per tool so results are
                # deterministic regardless of thread scheduling order
                from unittest.mock import MagicMock

                def factory_side_effect(*args, **kwargs):
                    variable = kwargs.get("variable")
                    mock_instance = MagicMock()
                    if variable == "var1":
                        mock_instance.query.return_value = "result1"
                    elif variable == "var2":
                        mock_instance.query.return_value = "result2"
                    return mock_instance

                MockFactory.side_effect = factory_side_effect

                result_inputs = fetcher.fetch(
                    tenant_id="tenant1",
                    app_id="app1",
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query,
                )

                # Each tool gets its deterministic result regardless of thread completion order
                assert result_inputs["var1"] == "result1"
                assert result_inputs["var2"] == "result2"
                assert result_inputs["input_key"] == "input_value"
                assert len(result_inputs) == 3

                # Verify factory calls
                assert MockFactory.call_count == 2
                MockFactory.assert_any_call(
                    name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"}
                )
                MockFactory.assert_any_call(
                    name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"}
                )

    def test_fetch_no_tools(self):
        # fetch only calls current_app._get_current_object() inside the
        # per-tool submit loop, so no app context is needed when the
        # tool list is empty.
        fetcher = ExternalDataFetch()
        inputs = {"input_key": "input_value"}
        result_inputs = fetcher.fetch(
            tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query"
        )
        assert result_inputs == inputs
        assert result_inputs is not inputs  # Should be a copy

    def test_fetch_with_none_variable(self, app):
        with app.app_context():
            fetcher = ExternalDataFetch()
            tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})

            # Patch _query_external_data_tool to return a None variable
            with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query:
                mock_query.return_value = (None, "some_result")
                result_inputs = fetcher.fetch(
                    tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q"
                )
                assert "var1" not in result_inputs
                assert result_inputs == {"in": "val"}

    def test_query_external_data_tool(self, app):
        fetcher = ExternalDataFetch()
        tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})

        with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
            mock_factory_instance = MockFactory.return_value
            mock_factory_instance.query.return_value = "query_result"

            var, res = fetcher._query_external_data_tool(
                flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q"
            )

            assert var == "var1"
            assert res == "query_result"
            MockFactory.assert_called_once_with(
                name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"}
            )
            mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q")


@@ -1,58 +0,0 @@
from unittest.mock import MagicMock, patch

from core.extension.extensible import ExtensionModule
from core.external_data_tool.factory import ExternalDataToolFactory


def test_external_data_tool_factory_init():
    with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
        mock_extension_class = MagicMock()
        mock_code_based_extension.extension_class.return_value = mock_extension_class

        name = "test_tool"
        tenant_id = "tenant_123"
        app_id = "app_456"
        variable = "var_v"
        config = {"key": "value"}

        factory = ExternalDataToolFactory(name, tenant_id, app_id, variable, config)

        mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name)
        mock_extension_class.assert_called_once_with(
            tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
        )


def test_external_data_tool_factory_validate_config():
    with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
        mock_extension_class = MagicMock()
        mock_code_based_extension.extension_class.return_value = mock_extension_class

        name = "test_tool"
        tenant_id = "tenant_123"
        config = {"key": "value"}

        ExternalDataToolFactory.validate_config(name, tenant_id, config)

        mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name)
        mock_extension_class.validate_config.assert_called_once_with(tenant_id, config)


def test_external_data_tool_factory_query():
    with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
        mock_extension_class = MagicMock()
        mock_extension_instance = MagicMock()
        mock_extension_class.return_value = mock_extension_instance
        mock_code_based_extension.extension_class.return_value = mock_extension_class
        mock_extension_instance.query.return_value = "query_result"

        factory = ExternalDataToolFactory("name", "tenant", "app", "var", {})
        inputs = {"input_key": "input_value"}
        query = "search_query"

        result = factory.query(inputs, query)

        assert result == "query_result"
        mock_extension_instance.query.assert_called_once_with(inputs, query)


@@ -1,103 +0,0 @@
import pytest

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.prompts import (
    RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
    RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
    RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
)


class TestRuleConfigGeneratorOutputParser:
    def test_get_format_instructions(self):
        parser = RuleConfigGeneratorOutputParser()
        instructions = parser.get_format_instructions()
        assert instructions == (
            RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
            RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
            RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
        )

    def test_parse_success(self):
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": ["var1", "var2"],
            "opening_statement": "Hello!"
        }
        ```
        """
        result = parser.parse(text)
        assert result["prompt"] == "This is a prompt"
        assert result["variables"] == ["var1", "var2"]
        assert result["opening_statement"] == "Hello!"

    def test_parse_invalid_json(self):
        parser = RuleConfigGeneratorOutputParser()
        text = "invalid json"
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Parsing text" in str(excinfo.value)
        assert "could not find json block in the output" in str(excinfo.value)

    def test_parse_missing_keys(self):
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": ["var1", "var2"]
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "expected key `opening_statement` to be present" in str(excinfo.value)

    def test_parse_wrong_type_prompt(self):
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": 123,
            "variables": ["var1", "var2"],
            "opening_statement": "Hello!"
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Expected 'prompt' to be a string" in str(excinfo.value)

    def test_parse_wrong_type_variables(self):
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": "not a list",
            "opening_statement": "Hello!"
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Expected 'variables' to be a list" in str(excinfo.value)

    def test_parse_wrong_type_opening_statement(self):
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": ["var1", "var2"],
            "opening_statement": 123
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Expected 'opening_statement' to be a str" in str(excinfo.value)


@@ -1,402 +0,0 @@
import json
from unittest.mock import MagicMock, patch

import pytest

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
    ResponseFormat,
    _handle_native_json_schema,
    _handle_prompt_based_schema,
    _parse_structured_output,
    _prepare_schema_for_model,
    _set_response_format,
    convert_boolean_to_string,
    invoke_llm_with_structured_output,
    remove_additional_properties,
)
from core.model_manager import ModelInstance
from dify_graph.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMResultWithStructuredOutput,
    LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType


class TestStructuredOutput:
    def test_remove_additional_properties(self):
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "additionalProperties": False,
            "nested": {"type": "object", "additionalProperties": True},
            "items": [{"type": "object", "additionalProperties": False}],
        }
        remove_additional_properties(schema)
        assert "additionalProperties" not in schema
        assert "additionalProperties" not in schema["nested"]
        assert "additionalProperties" not in schema["items"][0]

        # Non-dict input should not raise
        remove_additional_properties(None)
        remove_additional_properties([])

    def test_convert_boolean_to_string(self):
        schema = {
            "type": "object",
            "properties": {
                "is_active": {"type": "boolean"},
                "tags": {"type": "array", "items": {"type": "boolean"}},
                "list_schema": [{"type": "boolean"}],
            },
        }
        convert_boolean_to_string(schema)
        assert schema["properties"]["is_active"]["type"] == "string"
        assert schema["properties"]["tags"]["items"]["type"] == "string"
        assert schema["properties"]["list_schema"][0]["type"] == "string"

        # Non-dict input should not raise
        convert_boolean_to_string(None)
        convert_boolean_to_string([])

    def test_parse_structured_output_valid(self):
        text = '{"key": "value"}'
        assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_non_dict_valid_json(self):
        # Even if the text is valid JSON, a non-dict result should go through repair
        text = '["a", "b"]'
        with patch("json_repair.loads") as mock_repair:
            mock_repair.return_value = {"key": "value"}
            assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_not_dict_fail_via_validate(self):
        # Force TypeAdapter to return a non-dict to trigger the failure branch
        with patch("pydantic.TypeAdapter.validate_json") as mock_validate:
            mock_validate.return_value = ["a list"]
            with pytest.raises(OutputParserError) as excinfo:
                _parse_structured_output('["a list"]')
            assert "Failed to parse structured output" in str(excinfo.value)

    def test_parse_structured_output_repair_success(self):
        text = "{'key': 'value'}"  # Invalid JSON (single quotes); json_repair should handle it
        assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_repair_list(self):
        # Deepseek-r1 case: result is a list containing a dict
        text = '[{"key": "value"}]'
        assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_repair_list_no_dict(self):
        # Deepseek-r1 case: result is a list with no dict
        text = "[1, 2, 3]"
        assert _parse_structured_output(text) == {}

    def test_parse_structured_output_repair_fail(self):
        text = "not a json at all"
        with patch("json_repair.loads") as mock_repair:
            mock_repair.return_value = "still not a dict or list"
            with pytest.raises(OutputParserError):
                _parse_structured_output(text)

    def test_set_response_format(self):
        # Test JSON
        params = {}
        rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON],
            )
        ]
        _set_response_format(params, rules)
        assert params["response_format"] == ResponseFormat.JSON

        # Test JSON_OBJECT
        params = {}
        rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON_OBJECT],
            )
        ]
        _set_response_format(params, rules)
        assert params["response_format"] == ResponseFormat.JSON_OBJECT

    def test_handle_native_json_schema(self):
        provider = "openai"
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gpt-4"
        structured_output_schema = {"type": "object"}
        model_parameters = {}
        rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON_SCHEMA],
            )
        ]

        updated_params = _handle_native_json_schema(
            provider, model_schema, structured_output_schema, model_parameters, rules
        )

        assert "json_schema" in updated_params
        assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"}
        assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA

    def test_handle_native_json_schema_no_format_rule(self):
        provider = "openai"
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gpt-4"
        structured_output_schema = {"type": "object"}
        model_parameters = {}
        rules = []

        updated_params = _handle_native_json_schema(
            provider, model_schema, structured_output_schema, model_parameters, rules
        )

        assert "json_schema" in updated_params
        assert "response_format" not in updated_params

    def test_handle_prompt_based_schema_with_system_prompt(self):
        prompt_messages = [
            SystemPromptMessage(content="Existing system prompt"),
            UserPromptMessage(content="User question"),
        ]
        schema = {"type": "object"}

        result = _handle_prompt_based_schema(prompt_messages, schema)

        assert len(result) == 2
        assert isinstance(result[0], SystemPromptMessage)
        assert "Existing system prompt" in result[0].content
        assert json.dumps(schema) in result[0].content
        assert isinstance(result[1], UserPromptMessage)

    def test_handle_prompt_based_schema_without_system_prompt(self):
        prompt_messages = [UserPromptMessage(content="User question")]
        schema = {"type": "object"}

        result = _handle_prompt_based_schema(prompt_messages, schema)

        assert len(result) == 2
        assert isinstance(result[0], SystemPromptMessage)
        assert json.dumps(schema) in result[0].content
        assert isinstance(result[1], UserPromptMessage)

    def test_prepare_schema_for_model_gemini(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gemini-1.5-pro"
        schema = {"type": "object", "additionalProperties": False}
        result = _prepare_schema_for_model("google", model_schema, schema)
        assert "additionalProperties" not in result

    def test_prepare_schema_for_model_ollama(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "llama3"
        schema = {"type": "object"}
        result = _prepare_schema_for_model("ollama", model_schema, schema)
        assert result == schema

    def test_prepare_schema_for_model_default(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gpt-4"
        schema = {"type": "object"}
        result = _prepare_schema_for_model("openai", model_schema, schema)
        assert result == {"schema": schema, "name": "llm_response"}

    def test_invoke_llm_with_structured_output_no_stream_native(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = True
        model_schema.parameter_rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON_SCHEMA],
            )
        ]
        model_schema.model = "gpt-4o"

        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
        mock_result.model = "gpt-4o"
        mock_result.usage = LLMUsage.empty_usage()
        mock_result.system_fingerprint = "fp_native"
        mock_result.prompt_messages = [UserPromptMessage(content="hi")]
        model_instance.invoke_llm.return_value = mock_result

        result = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={"type": "object"},
            stream=False,
        )

        assert isinstance(result, LLMResultWithStructuredOutput)
        assert result.structured_output == {"result": "success"}
        assert result.system_fingerprint == "fp_native"

    def test_invoke_llm_with_structured_output_no_stream_prompt_based(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON],
            )
        ]
        model_schema.model = "claude-3"

        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
        mock_result.model = "claude-3"
        mock_result.usage = LLMUsage.empty_usage()
        mock_result.system_fingerprint = "fp_prompt"
        mock_result.prompt_messages = []
        model_instance.invoke_llm.return_value = mock_result

        result = invoke_llm_with_structured_output(
            provider="anthropic",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={"type": "object"},
            stream=False,
        )

        assert isinstance(result, LLMResultWithStructuredOutput)
        assert result.structured_output == {"result": "success"}
        assert result.system_fingerprint == "fp_prompt"

    def test_invoke_llm_with_structured_output_no_string_error(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = []

        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")])
        model_instance.invoke_llm.return_value = mock_result

        with pytest.raises(OutputParserError) as excinfo:
            invoke_llm_with_structured_output(
                provider="anthropic",
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=[],
                json_schema={},
                stream=False,
            )
        assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value)

    def test_invoke_llm_with_structured_output_stream(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = []
        model_schema.model = "gpt-4"

        model_instance = MagicMock(spec=ModelInstance)

        # Mock chunks
        chunk1 = MagicMock(spec=LLMResultChunk)
        chunk1.delta = LLMResultChunkDelta(
            index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage()
        )
        chunk1.prompt_messages = [UserPromptMessage(content="hi")]
        chunk1.system_fingerprint = "fp1"

        chunk2 = MagicMock(spec=LLMResultChunk)
        chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}'))
        chunk2.prompt_messages = [UserPromptMessage(content="hi")]
        chunk2.system_fingerprint = "fp1"

        chunk3 = MagicMock(spec=LLMResultChunk)
        chunk3.delta = LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(
                content=[
                    TextPromptMessageContent(data=" "),
                ]
            ),
        )
        chunk3.prompt_messages = [UserPromptMessage(content="hi")]
        chunk3.system_fingerprint = "fp1"

        event4 = MagicMock()
        event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=""))

        model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4]

        generator = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={},
            stream=True,
        )

        chunks = list(generator)
        assert len(chunks) == 5
        assert chunks[-1].structured_output == {"key": "value"}
        assert chunks[-1].system_fingerprint == "fp1"
        assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")]

    def test_invoke_llm_with_structured_output_stream_no_id_events(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = []
        model_schema.model = "gpt-4"

        model_instance = MagicMock(spec=ModelInstance)
        model_instance.invoke_llm.return_value = []

        generator = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[],
            json_schema={},
            stream=True,
        )

        with pytest.raises(OutputParserError):
            list(generator)

    def test_parse_structured_output_empty_string(self):
        with pytest.raises(OutputParserError):
            _parse_structured_output("")


@@ -1,589 +0,0 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
class TestLLMGenerator:
@pytest.fixture
def mock_model_instance(self):
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
instance = MagicMock()
mock_manager.return_value.get_default_model_instance.return_value = instance
mock_manager.return_value.get_model_instance.return_value = instance
yield instance
@pytest.fixture
def model_config_entity(self):
return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7})
def test_generate_conversation_name_success(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"})
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace:
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "Test Conversation Name"
mock_trace.assert_called_once()
def test_generate_conversation_name_truncated(self, mock_model_instance):
long_query = "a" * 2100
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"})
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", long_query)
assert name == "Short Name"
def test_generate_conversation_name_empty_answer(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = ""
mock_model_instance.invoke_llm.return_value = mock_response
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == ""
def test_generate_conversation_name_json_repair(self, mock_model_instance):
mock_response = MagicMock()
# Invalid JSON that json_repair can fix
mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}"
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "Repaired Name"
def test_generate_conversation_name_not_dict_result(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '["not a dict"]'
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "test query"
def test_generate_conversation_name_no_output_in_dict(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"something": "else"}'
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "test query"
def test_generate_conversation_name_long_output(self, mock_model_instance):
long_output = "a" * 100
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output})
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert len(name) == 78 # 75 + "..."
assert name.endswith("...")
def test_generate_suggested_questions_after_answer_success(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]'
mock_model_instance.invoke_llm.return_value = mock_response
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert len(questions) == 2
assert questions[0] == "Question 1?"
def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance):
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed")
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert questions == []
def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance):
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert questions == []
def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance):
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert questions == []
def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=True
)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "Generated Prompt"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert result["prompt"] == "Generated Prompt"
assert result["error"] == ""
def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=True
)
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate rule config" in result["error"]
def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=True
)
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate rule config" in result["error"]
assert "Random error" in result["error"]
def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
# Mocking 3 calls for invoke_llm
mock_res1 = MagicMock()
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
mock_res2 = MagicMock()
mock_res2.message.get_text_content.return_value = '"var1", "var2"'
mock_res3 = MagicMock()
mock_res3.message.get_text_content.return_value = "Opening Statement"
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, mock_res3]
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert result["prompt"] == "Step 1 Prompt"
assert result["variables"] == ["var1", "var2"]
assert result["opening_statement"] == "Opening Statement"
assert result["error"] == ""
def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate prefix prompt" in result["error"]
def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
mock_res1 = MagicMock()
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
# Step 2 fails
mock_model_instance.invoke_llm.side_effect = [mock_res1, InvokeError("Step 2 Failed"), MagicMock()]
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate variables" in result["error"]
def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
mock_res1 = MagicMock()
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
mock_res2 = MagicMock()
mock_res2.message.get_text_content.return_value = '"var1"'
# Step 3 fails
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, InvokeError("Step 3 Failed")]
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate conversation opener" in result["error"]
def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
# Mock any step to throw Exception
mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to handle unexpected exception" in result["error"]
assert "Unexpected multi-step error" in result["error"]
def test_generate_code_python_success(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(
instruction="print hello", code_language="python", model_config=model_config_entity
)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "print('hello')"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_code("tenant_id", payload)
assert result["code"] == "print('hello')"
assert result["language"] == "python"
def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(
instruction="console log hello", code_language="javascript", model_config=model_config_entity
)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "console.log('hello')"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_code("tenant_id", payload)
assert result["code"] == "console.log('hello')"
assert result["language"] == "javascript"
def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
result = LLMGenerator.generate_code("tenant_id", payload)
assert "Failed to generate code" in result["error"]
def test_generate_code_exception(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.generate_code("tenant_id", payload)
assert "An unexpected error occurred" in result["error"]
def test_generate_qa_document_success(self, mock_model_instance):
mock_response = MagicMock(spec=LLMResult)
mock_response.message = MagicMock()
mock_response.message.get_text_content.return_value = "QA Document Content"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_qa_document("tenant_id", "query", "English")
assert result == "QA Document Content"
def test_generate_qa_document_type_error(self, mock_model_instance):
mock_model_instance.invoke_llm.return_value = "Not an LLMResult"
with pytest.raises(TypeError, match="Expected LLMResult when stream=False"):
LLMGenerator.generate_qa_document("tenant_id", "query", "English")
def test_generate_structured_output_success(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"type": "object", "properties": {}}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_structured_output("tenant_id", payload)
parsed_output = json.loads(result["output"])
assert parsed_output["type"] == "object"
assert result["error"] == ""
def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "{'type': 'object'}"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_structured_output("tenant_id", payload)
parsed_output = json.loads(result["output"])
assert parsed_output["type"] == "object"
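The mocked output above is single-quoted pseudo-JSON, which `json.loads` rejects. As a hedged illustration of the repair step this test exercises (the production code presumably delegates to a dedicated JSON-repair helper), a stdlib-only sketch can coerce a Python-style dict literal into canonical JSON:

```python
import ast
import json

# Illustrative stdlib-only repair: parse the single-quoted pseudo-JSON as a
# Python literal, then re-serialize it as canonical double-quoted JSON.
raw = "{'type': 'object'}"
repaired = json.dumps(ast.literal_eval(raw))
parsed = json.loads(repaired)
assert parsed["type"] == "object"
```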
def test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "true" # parsed as bool
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_structured_output("tenant_id", payload)
assert "An unexpected error occurred" in result["error"]
assert "Failed to parse structured output" in result["error"]
def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
result = LLMGenerator.generate_structured_output("tenant_id", payload)
assert "Failed to generate JSON Schema" in result["error"]
def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.generate_structured_output("tenant_id", payload)
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
# Mock __instruction_modify_common call via invoke_llm
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert result == {"modified": "prompt"}
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
last_run = MagicMock()
last_run.query = "q"
last_run.answer = "a"
last_run.error = "e"
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert result == {"modified": "prompt"}
def test_instruction_modify_workflow_app_not_found(self):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="App not found."):
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
def test_instruction_modify_workflow_no_workflow(self):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = None
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", workflow_service)
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = workflow
last_run = MagicMock()
last_run.node_type = "llm"
last_run.status = "s"
last_run.error = "e"
# Return regular values, not Mocks
last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]}
last_run.load_full_inputs.return_value = {"in": "val"}
workflow_service.get_node_last_run.return_value = last_run
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_workflow(
"tenant_id",
"flow_id",
"node_id",
"current",
"instruction",
model_config_entity,
"ideal",
workflow_service,
)
assert result == {"modified": "workflow"}
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = workflow
workflow_service.get_node_last_run.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_workflow(
"tenant_id",
"flow_id",
"node_id",
"current",
"instruction",
model_config_entity,
"ideal",
workflow_service,
)
assert result == {"modified": "fallback"}
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
workflow = MagicMock()
# Cause exception in node_type logic
workflow.graph_dict = {"graph": {"nodes": []}}
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = workflow
workflow_service.get_node_last_run.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_workflow(
"tenant_id",
"flow_id",
"node_id",
"current",
"instruction",
model_config_entity,
"ideal",
workflow_service,
)
assert result == {"modified": "fallback"}
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = workflow
last_run = MagicMock()
last_run.node_type = "llm"
last_run.status = "s"
last_run.error = "e"
# Return regular empty list, not a Mock
last_run.execution_metadata_dict = {"agent_log": []}
last_run.load_full_inputs.return_value = {}
workflow_service.get_node_last_run.return_value = last_run
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_workflow(
"tenant_id",
"flow_id",
"node_id",
"current",
"instruction",
model_config_entity,
"ideal",
workflow_service,
)
assert result == {"modified": "workflow"}
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
# Testing placeholder replacement via instruction_modify_legacy for convenience
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"ok": true}'
mock_model_instance.invoke_llm.return_value = mock_response
instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}"
LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current_val", instruction, model_config_entity, "ideal"
)
# Verify the call to invoke_llm contains the replaced instruction
args, kwargs = mock_model_instance.invoke_llm.call_args
prompt_messages = kwargs["prompt_messages"]
user_msg = prompt_messages[1].content
user_msg_dict = json.loads(user_msg)
assert "null" in user_msg_dict["instruction"] # because last_run is None and current is current_val etc.
assert "current_val" in user_msg_dict["instruction"]
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "No braces here"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert "An unexpected error occurred" in result["error"]
assert "Could not find a valid JSON object" in result["error"]
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
# The exception message is "Expected a JSON object, but got list"
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity):
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
instance = MagicMock()
mock_manager.return_value.get_model_instance.return_value = instance
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"ok": true}'
instance.invoke_llm.return_value = mock_response
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = workflow
workflow_service.get_node_last_run.return_value = None
LLMGenerator.instruction_modify_workflow(
"tenant_id",
"flow_id",
"node_id",
"current",
"instruction",
model_config_entity,
"ideal",
workflow_service,
)
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert "Failed to generate code" in result["error"]
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "No JSON here"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert "An unexpected error occurred" in result["error"]


@@ -82,68 +82,6 @@ class TestTraceContextFilter:
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == "051581bf3bb55c45"
def test_otel_context_invalid_trace_id(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0
mock_context.is_valid = True
mock_span.get_span_context.return_value = mock_context
# Use mocks for base context to ensure we can test the fallback
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("core.logging.filters.get_trace_id", return_value=""),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == ""
def test_otel_context_invalid_span_id(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0
mock_context.is_valid = True
mock_span.get_span_context.return_value = mock_context
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == ""
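The expected strings in these assertions follow OpenTelemetry's canonical ID encoding: a 128-bit trace id renders as 32 zero-padded lowercase hex digits and a 64-bit span id as 16. A quick sanity check of the values used above:

```python
# OTel renders ids as zero-padded lowercase hex (trace: 32 chars, span: 16).
trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
span_id = 0x051581BF3BB55C45
assert format(trace_id, "032x") == "5b8aa5a2d2c872e8321cf37308d69df2"
assert format(span_id, "016x") == "051581bf3bb55c45"
```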
def test_otel_context_span_none(self, log_record):
from core.logging.filters import TraceContextFilter
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=None),
mock.patch("core.logging.filters.get_trace_id", return_value=""),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == ""
def test_otel_context_exception(self, log_record):
from core.logging.filters import TraceContextFilter
# Trigger exception in OTEL block
with (
mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception),
mock.patch("core.logging.filters.get_trace_id", return_value=""),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == ""
class TestIdentityContextFilter:
def test_sets_empty_identity_without_request_context(self, log_record):
@@ -176,119 +114,3 @@ class TestIdentityContextFilter:
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""
def test_sets_empty_identity_unauthenticated(self, log_record):
from core.logging.filters import IdentityContextFilter
mock_user = mock.MagicMock()
mock_user.is_authenticated = False
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.user_id == ""
def test_sets_identity_for_account(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockAccount:
pass
mock_user = MockAccount()
mock_user.id = "account_id"
mock_user.current_tenant_id = "tenant_id"
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.Account", MockAccount),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == "tenant_id"
assert log_record.user_id == "account_id"
assert log_record.user_type == "account"
def test_sets_identity_for_account_no_tenant(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockAccount:
pass
mock_user = MockAccount()
mock_user.id = "account_id"
mock_user.current_tenant_id = None
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.Account", MockAccount),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == ""
assert log_record.user_id == "account_id"
assert log_record.user_type == "account"
def test_sets_identity_for_end_user(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockEndUser:
pass
class AnotherClass:
pass
mock_user = MockEndUser()
mock_user.id = "end_user_id"
mock_user.tenant_id = "tenant_id"
mock_user.type = "custom_type"
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.model.EndUser", MockEndUser),
mock.patch("models.Account", AnotherClass),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == "tenant_id"
assert log_record.user_id == "end_user_id"
assert log_record.user_type == "custom_type"
def test_sets_identity_for_end_user_default_type(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockEndUser:
pass
class AnotherClass:
pass
mock_user = MockEndUser()
mock_user.id = "end_user_id"
mock_user.tenant_id = "tenant_id"
mock_user.type = None
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.model.EndUser", MockEndUser),
mock.patch("models.Account", AnotherClass),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == "tenant_id"
assert log_record.user_id == "end_user_id"
assert log_record.user_type == "end_user"


@@ -1,39 +1,27 @@
"""Unit tests for MCP OAuth authentication flow."""
import json
from unittest.mock import Mock, patch
import httpx
import pytest
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity
from core.helper import ssrf_proxy
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_parse_token_response,
_retrieve_redis_state,
auth,
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
check_support_resource_discovery,
client_credentials_flow,
discover_oauth_authorization_server_metadata,
discover_oauth_metadata,
discover_protected_resource_metadata,
exchange_authorization,
generate_pkce_challenge,
get_effective_scope,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
@@ -776,555 +764,3 @@ class TestAuthOrchestration:
auth(mock_provider, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)
def test_generate_pkce_challenge(self):
verifier, challenge = generate_pkce_challenge()
assert verifier
assert challenge
assert "=" not in verifier
assert "=" not in challenge
def test_build_protected_resource_metadata_discovery_urls(self):
# Case 1: WWW-Auth URL provided
urls = build_protected_resource_metadata_discovery_urls(
"https://auth.example.com/prm", "https://api.example.com"
)
assert "https://auth.example.com/prm" in urls
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
# Case 2: No WWW-Auth URL, with path
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1")
assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
# Case 3: No path
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
def test_build_oauth_authorization_server_metadata_discovery_urls(self):
# Case 1: with auth_server_url
urls = build_oauth_authorization_server_metadata_discovery_urls(
"https://auth.example.com", "https://api.example.com"
)
assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls
assert "https://auth.example.com/.well-known/openid-configuration" in urls
# Case 2: with path
urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant")
assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls
assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls
@patch("core.helper.ssrf_proxy.get")
def test_discover_protected_resource_metadata(self, mock_get):
# Success
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"resource": "https://api.example.com",
"authorization_servers": ["https://auth"],
}
mock_get.return_value = mock_response
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is not None
assert result.resource == "https://api.example.com"
# 404 then Success
res404 = Mock()
res404.status_code = 404
mock_get.side_effect = [res404, mock_response]
result = discover_protected_resource_metadata(None, "https://api.example.com/path")
assert result is not None
assert result.resource == "https://api.example.com"
# Error handling
mock_get.side_effect = httpx.RequestError("Error")
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_authorization_server_metadata(self, mock_get):
# Success
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/auth",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is not None
assert result.authorization_endpoint == "https://auth.example.com/auth"
# 404
res404 = Mock()
res404.status_code = 404
mock_get.side_effect = [res404, mock_response]
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant")
assert result is not None
assert result.authorization_endpoint == "https://auth.example.com/auth"
# ValidationError
mock_response.json.return_value = {"invalid": "data"}
mock_get.side_effect = None
mock_get.return_value = mock_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
def test_get_effective_scope(self):
prm = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=["https://auth"],
scopes_supported=["read", "write"],
)
asm = OAuthMetadata(
authorization_endpoint="https://auth.example.com/auth",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
scopes_supported=["openid", "profile"],
)
# 1. WWW-Auth priority
assert get_effective_scope("scope1", prm, asm, "client") == "scope1"
# 2. PRM priority
assert get_effective_scope(None, prm, asm, "client") == "read write"
# 3. ASM priority
assert get_effective_scope(None, None, asm, "client") == "openid profile"
# 4. Client configured
assert get_effective_scope(None, None, None, "client") == "client"
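The four assertions above pin down a precedence chain. A sketch of that logic, with the metadata objects reduced to their scopes_supported lists for illustration:

```python
def effective_scope(www_auth_scope, prm_scopes, asm_scopes, client_scope):
    # Precedence mirrored from the test: WWW-Authenticate hint, then
    # protected-resource metadata, then auth-server metadata, then the
    # client-configured fallback. Multi-scope lists join with spaces.
    if www_auth_scope:
        return www_auth_scope
    if prm_scopes:
        return " ".join(prm_scopes)
    if asm_scopes:
        return " ".join(asm_scopes)
    return client_scope

assert effective_scope(None, ["read", "write"], ["openid"], "client") == "read write"
```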
@patch("core.mcp.auth.auth_flow.redis_client")
def test_redis_state_management(self, mock_redis):
state_data = OAuthCallbackState(
provider_id="p1",
tenant_id="t1",
server_url="https://api",
metadata=None,
client_information=OAuthClientInformation(client_id="c1"),
code_verifier="cv",
redirect_uri="https://re",
)
# Create
state_key = _create_secure_redis_state(state_data)
assert state_key
mock_redis.setex.assert_called_once()
# Retrieve Success
mock_redis.get.return_value = state_data.model_dump_json()
retrieved = _retrieve_redis_state(state_key)
assert retrieved.provider_id == "p1"
mock_redis.delete.assert_called_once()
# Retrieve Failure - Not found
mock_redis.get.return_value = None
with pytest.raises(ValueError, match="expired or does not exist"):
_retrieve_redis_state("absent")
# Retrieve Failure - Invalid JSON
mock_redis.get.return_value = "invalid"
with pytest.raises(ValueError, match="Invalid state parameter"):
_retrieve_redis_state("invalid")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_handle_callback(self, mock_exchange, mock_retrieve):
state = Mock(spec=OAuthCallbackState)
state.server_url = "https://api"
state.metadata = None
state.client_information = Mock()
state.code_verifier = "cv"
state.redirect_uri = "https://re"
mock_retrieve.return_value = state
tokens = Mock(spec=OAuthTokens)
mock_exchange.return_value = tokens
s, t = handle_callback("key", "code")
assert s == state
assert t == tokens
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery(self, mock_get):
# Case 1: authorization_servers (plural)
res = Mock()
res.status_code = 200
res.json.return_value = {"authorization_servers": ["https://auth1"]}
mock_get.return_value = res
supported, url = check_support_resource_discovery("https://api")
assert supported is True
assert url == "https://auth1"
# Case 2: authorization_server_url (singular alias)
res.json.return_value = {"authorization_server_url": ["https://auth2"]}
supported, url = check_support_resource_discovery("https://api")
assert supported is True
assert url == "https://auth2"
# Case 3: Missing fields
res.json.return_value = {"nothing": []}
supported, url = check_support_resource_discovery("https://api")
assert supported is False
# Case 4: 404
res.status_code = 404
supported, url = check_support_resource_discovery("https://api")
assert supported is False
# Case 5: RequestError
mock_get.side_effect = httpx.RequestError("Error")
supported, url = check_support_resource_discovery("https://api")
assert supported is False
def test_discover_oauth_metadata(self):
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
mock_prm.return_value = ProtectedResourceMetadata(
resource="https://api", authorization_servers=["https://auth"]
)
mock_asm.return_value = Mock(spec=OAuthMetadata)
asm, prm, hint = discover_oauth_metadata("https://api")
assert asm == mock_asm.return_value
assert prm == mock_prm.return_value
mock_asm.assert_called_with("https://auth", "https://api", None)
def test_start_authorization(self):
metadata = OAuthMetadata(
authorization_endpoint="https://auth/authorize",
token_endpoint="https://auth/token",
response_types_supported=["code"],
)
client_info = OAuthClientInformation(client_id="c1")
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create:
mock_create.return_value = "state-key"
# Success with scope
url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read")
assert "scope=read" in url
assert "state=state-key" in url
# Success without metadata
url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1")
assert "https://api/authorize" in url
# Failure: incompatible auth server
metadata.response_types_supported = ["implicit"]
with pytest.raises(ValueError, match="Incompatible auth server"):
start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1")
def test_parse_token_response(self):
# Case 1: JSON
res = Mock()
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
tokens = _parse_token_response(res)
assert tokens.access_token == "at"
# Case 2: Form-urlencoded
res.headers = {"content-type": "application/x-www-form-urlencoded"}
res.text = "access_token=at2&token_type=Bearer"
tokens = _parse_token_response(res)
assert tokens.access_token == "at2"
# Case 3: No content-type, but JSON
res.headers = {}
res.json.return_value = {"access_token": "at3", "token_type": "Bearer"}
tokens = _parse_token_response(res)
assert tokens.access_token == "at3"
# Case 4: No content-type, not JSON, but Form
res.json.side_effect = json.JSONDecodeError("msg", "doc", 0)
res.text = "access_token=at4&token_type=Bearer"
tokens = _parse_token_response(res)
assert tokens.access_token == "at4"
# Case 5: Validation Error fallback
res.json.side_effect = ValidationError.from_exception_data("error", [])
res.text = "access_token=at5&token_type=Bearer"
tokens = _parse_token_response(res)
assert tokens.access_token == "at5"
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization(self, mock_post):
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
metadata = OAuthMetadata(
authorization_endpoint="https://auth/authorize",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
# Success
res = Mock()
res.is_success = True
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
mock_post.return_value = res
tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
assert tokens.access_token == "at"
# Failure: Unsupported grant type
metadata.grant_types_supported = ["client_credentials"]
with pytest.raises(ValueError, match="Incompatible auth server"):
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
# Failure: HTTP error
metadata.grant_types_supported = ["authorization_code"]
res.is_success = False
res.status_code = 400
with pytest.raises(ValueError, match="Token exchange failed"):
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
@patch("core.helper.ssrf_proxy.post")
def test_refresh_authorization(self, mock_post):
# Case 1: with client_secret
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
# Success
res = Mock()
res.is_success = True
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"}
mock_post.return_value = res
tokens = refresh_authorization("https://api", None, client_info, "rt")
assert tokens.access_token == "at_new"
assert mock_post.call_args[1]["data"]["client_secret"] == "s1"
# Failure: MaxRetriesExceededError
mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries")
with pytest.raises(MCPRefreshTokenError):
refresh_authorization("https://api", None, client_info, "rt")
# Failure: HTTP error
mock_post.side_effect = None
res.is_success = False
res.text = "error_msg"
with pytest.raises(MCPRefreshTokenError, match="error_msg"):
refresh_authorization("https://api", None, client_info, "rt")
# Failure: Incompatible metadata
metadata = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
with pytest.raises(ValueError, match="Incompatible auth server"):
refresh_authorization("https://api", metadata, client_info, "rt")
@patch("core.helper.ssrf_proxy.post")
def test_client_credentials_flow(self, mock_post):
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
# Success with secret
res = Mock()
res.is_success = True
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"}
mock_post.return_value = res
tokens = client_credentials_flow("https://api", None, client_info, "read")
assert tokens.access_token == "at_cc"
args, kwargs = mock_post.call_args
assert "Authorization" in kwargs["headers"]
# Success without secret
client_info_no_secret = OAuthClientInformation(client_id="c2")
tokens = client_credentials_flow("https://api", None, client_info_no_secret)
args, kwargs = mock_post.call_args
assert kwargs["data"]["client_id"] == "c2"
# Failure: Incompatible metadata
metadata = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
with pytest.raises(ValueError, match="Incompatible auth server"):
client_credentials_flow("https://api", metadata, client_info)
# Failure: HTTP error
res.is_success = False
res.status_code = 401
res.text = "Unauthorized"
with pytest.raises(ValueError, match="Client credentials token request failed"):
client_credentials_flow("https://api", None, client_info)
@patch("core.helper.ssrf_proxy.post")
def test_register_client(self, mock_post):
# Case 1: Success with metadata
metadata = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
registration_endpoint="https://auth/register",
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"])
res = Mock()
res.is_success = True
res.json.return_value = {
"client_id": "c_new",
"client_secret": "s_new",
"client_name": "Dify",
"redirect_uris": ["https://re"],
}
mock_post.return_value = res
info = register_client("https://api", metadata, client_metadata)
assert info.client_id == "c_new"
# Case 2: Success without metadata
info = register_client("https://api", None, client_metadata)
assert mock_post.call_args[0][0] == "https://api/register"
# Case 3: Metadata provided but no endpoint
metadata.registration_endpoint = None
with pytest.raises(ValueError, match="does not support dynamic client registration"):
register_client("https://api", metadata, client_metadata)
# Failure: HTTP
res.is_success = False
res.raise_for_status = Mock()
res.status_code = 400
# When is_success is False, register_client should call raise_for_status()
register_client("https://api", None, client_metadata)
res.raise_for_status.assert_called_once()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_failures(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
# Case 1: No server metadata
mock_discover.return_value = (None, None, None)
with pytest.raises(ValueError, match="Failed to discover OAuth metadata"):
auth(provider)
# Case 2: No client info, exchange code provided
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
)
mock_discover.return_value = (asm, None, None)
provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError, match="Existing OAuth client information is required"):
auth(provider, authorization_code="code")
# Case 3: client_credentials grant, but the required client_id/client_secret are missing
asm.grant_types_supported = ["client_credentials"]
with pytest.raises(ValueError, match="requires client_id and client_secret"):
auth(provider)
# Case 4: Client registration fails
asm.grant_types_supported = ["authorization_code"]
with patch("core.mcp.auth.auth_flow.register_client") as mock_reg:
mock_reg.side_effect = httpx.RequestError("Reg failed")
with pytest.raises(ValueError, match="Could not register OAuth client"):
auth(provider)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_client_credentials(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1")
provider.decrypt_credentials.return_value = {"scope": "read"}
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["client_credentials"],
)
mock_discover.return_value = (asm, None, None)
with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc:
mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer")
result = auth(provider)
assert result.response == {"result": "success"}
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
assert result.actions[0].data["grant_type"] == "client_credentials"
# Failure in CC flow
mock_cc.side_effect = ValueError("CC Failed")
with pytest.raises(ValueError, match="Client credentials flow failed"):
auth(provider)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_authorization_code(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
provider.decrypt_credentials.return_value = {}
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_discover.return_value = (asm, None, None)
# Case 1: Exchange code
with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve:
state = Mock(spec=OAuthCallbackState)
state.code_verifier = "cv"
state.redirect_uri = "https://re"
mock_retrieve.return_value = state
with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange:
mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer")
# Success
result = auth(provider, authorization_code="code", state_param="sp")
assert result.response == {"result": "success"}
# Missing state_param
with pytest.raises(ValueError, match="State parameter is required"):
auth(provider, authorization_code="code")
# Missing verifier in state
state.code_verifier = None
with pytest.raises(ValueError, match="Missing code_verifier"):
auth(provider, authorization_code="code", state_param="sp")
# Invalid state
mock_retrieve.side_effect = ValueError("Invalid")
with pytest.raises(ValueError, match="Invalid state parameter"):
auth(provider, authorization_code="code", state_param="sp")
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_refresh_failure(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
provider.decrypt_credentials.return_value = {}
provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt")
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_discover.return_value = (asm, None, None)
with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh:
mock_refresh.side_effect = ValueError("Refresh Failed")
with pytest.raises(ValueError, match="Could not refresh OAuth tokens"):
auth(provider)


@@ -322,475 +322,3 @@ def test_sse_client_concurrent_access():
assert len(received_messages) == 10
for i in range(10):
assert f"message_{i}" in received_messages
class TestStatusClasses:
"""Tests for _StatusReady and _StatusError data containers."""
def test_status_ready_stores_endpoint(self):
from core.mcp.client.sse_client import _StatusReady
status = _StatusReady("http://example.com/messages/")
assert status.endpoint_url == "http://example.com/messages/"
def test_status_error_stores_exception(self):
from core.mcp.client.sse_client import _StatusError
exc = ValueError("bad endpoint")
status = _StatusError(exc)
assert status.exc is exc
class TestSSETransportInit:
"""Tests for SSETransport default and explicit init values."""
def test_defaults(self):
from core.mcp.client.sse_client import SSETransport
t = SSETransport("http://example.com/sse")
assert t.url == "http://example.com/sse"
assert t.headers == {}
assert t.timeout == 5.0
assert t.sse_read_timeout == 60.0
assert t.endpoint_url is None
assert t.event_source is None
def test_explicit_headers_not_mutated(self):
from core.mcp.client.sse_client import SSETransport
hdrs = {"X-Foo": "bar"}
t = SSETransport("http://example.com/sse", headers=hdrs)
assert t.headers is hdrs
class TestHandleEndpointEvent:
"""Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch."""
def test_invalid_origin_puts_status_error(self):
from core.mcp.client.sse_client import SSETransport, _StatusError
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
# Provide a full URL with a different origin so urljoin keeps it as-is
transport._handle_endpoint_event("http://evil.com/messages/", status_queue)
result = status_queue.get_nowait()
assert isinstance(result, _StatusError)
assert "does not match" in str(result.exc)
def test_valid_origin_puts_status_ready(self):
from core.mcp.client.sse_client import SSETransport, _StatusReady
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
transport._handle_endpoint_event("/messages/?session_id=abc", status_queue)
result = status_queue.get_nowait()
assert isinstance(result, _StatusReady)
assert "example.com" in result.endpoint_url
class TestHandleSSEEvent:
"""Tests for SSETransport._handle_sse_event covering all match branches."""
def _make_sse(self, event_type: str, data: str):
sse = Mock()
sse.event = event_type
sse.data = data
return sse
def test_message_event_dispatched(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue)
item = read_queue.get_nowait()
assert hasattr(item, "message")
def test_unknown_event_logs_warning_and_does_nothing(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue)
assert read_queue.empty()
assert status_queue.empty()
class TestSSEReader:
"""Tests for SSETransport.sse_reader exception branches."""
def test_read_error_closes_cleanly(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
event_source = Mock()
event_source.iter_sse.side_effect = httpx.ReadError("connection reset")
transport.sse_reader(event_source, read_queue, status_queue)
# Finally block always puts None as sentinel
sentinel = read_queue.get_nowait()
assert sentinel is None
def test_generic_exception_puts_exc_then_none(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
boom = RuntimeError("unexpected!")
event_source = Mock()
event_source.iter_sse.side_effect = boom
transport.sse_reader(event_source, read_queue, status_queue)
exc_item = read_queue.get_nowait()
assert exc_item is boom
sentinel = read_queue.get_nowait()
assert sentinel is None
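Both reader tests above rely on the same contract: any exception from the event iterator is forwarded onto the read queue, and a `None` sentinel is always enqueued in a `finally` block so consumers know the stream ended. A minimal sketch of that pattern (a hypothetical `reader` helper, not the actual `sse_reader`):

```python
import queue

def reader(iter_events, out_q: queue.Queue) -> None:
    # Forward events to the queue; on failure, enqueue the exception
    # itself; always enqueue a None sentinel so consumers can stop.
    try:
        for item in iter_events():
            out_q.put(item)
    except Exception as exc:
        out_q.put(exc)
    finally:
        out_q.put(None)

q: queue.Queue = queue.Queue()

def broken():
    raise RuntimeError("stream broken")

reader(broken, q)
```

With a failing iterator the queue holds the exception first, then the sentinel, exactly the order the tests assert.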
class TestSendMessage:
"""Tests for SSETransport._send_message."""
def _make_session_message(self):
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
msg = types.JSONRPCMessage.model_validate_json(msg_json)
return types.SessionMessage(msg)
def test_sends_post_and_raises_for_status(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
mock_response = Mock()
mock_response.status_code = 200
mock_client = Mock()
mock_client.post.return_value = mock_response
session_msg = self._make_session_message()
transport._send_message(mock_client, "http://example.com/messages/", session_msg)
mock_client.post.assert_called_once()
mock_response.raise_for_status.assert_called_once()
class TestPostWriter:
"""Tests for SSETransport.post_writer exception branches."""
def _make_session_message(self):
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
msg = types.JSONRPCMessage.model_validate_json(msg_json)
return types.SessionMessage(msg)
def test_none_message_exits_loop(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
write_queue.put(None) # Signal shutdown immediately
mock_client = Mock()
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# Should put final None sentinel
sentinel = write_queue.get_nowait()
assert sentinel is None
def test_exception_in_message_put_back_to_queue(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
exc = ValueError("some error")
write_queue.put(exc) # Exception goes in first
write_queue.put(None) # Then shutdown signal
mock_client = Mock()
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# The exception should be re-queued, then None from loop exit, then None from finally
item1 = write_queue.get_nowait()
assert isinstance(item1, Exception)
def test_read_error_shuts_down_cleanly(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
session_msg = self._make_session_message()
write_queue.put(session_msg)
mock_response = Mock()
mock_response.status_code = 200
mock_client = Mock()
mock_client.post.side_effect = httpx.ReadError("connection dropped")
# post_writer calls _send_message, whose client.post raises ReadError,
# which propagates out of the while loop into the finally block
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# finally always puts None
sentinel = write_queue.get_nowait()
assert sentinel is None
def test_generic_exception_puts_exc_in_queue(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
session_msg = self._make_session_message()
write_queue.put(session_msg)
mock_client = Mock()
boom = RuntimeError("boom")
mock_client.post.side_effect = boom
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
exc_item = write_queue.get_nowait()
assert isinstance(exc_item, Exception)
sentinel = write_queue.get_nowait()
assert sentinel is None
def test_queue_empty_timeout_continues_loop(self):
"""Cover the 'except queue.Empty: continue' branch (line 188) in post_writer."""
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
mock_client = Mock()
# Patch queue.Queue.get so it raises Empty first, then returns None (shutdown)
call_count = {"n": 0}
original_get = write_queue.get
def patched_get(*args, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
raise queue.Empty
return original_get(*args, **kwargs)
write_queue.get = patched_get  # type: ignore[method-assign]
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# finally always puts None sentinel
sentinel = write_queue.get_nowait()
assert sentinel is None
assert call_count["n"] >= 2 # Empty on first, None on second (and possibly more retries)
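The test above covers the polling loop's `except queue.Empty: continue` branch. A minimal sketch of that loop shape (a hypothetical `poll_until_shutdown` helper, not the actual `post_writer`):

```python
import queue

def poll_until_shutdown(q: queue.Queue, timeout: float = 0.01) -> int:
    # Poll the queue with a timeout; queue.Empty just continues the
    # loop, and a None item signals shutdown. Returns the poll count.
    polls = 0
    while True:
        polls += 1
        try:
            item = q.get(timeout=timeout)
        except queue.Empty:
            continue
        if item is None:
            return polls

q: queue.Queue = queue.Queue()
q.put(None)
count = poll_until_shutdown(q)
```

Timing out on `get` never terminates the loop by itself; only the sentinel does.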
class TestWaitForEndpoint:
"""Tests for SSETransport._wait_for_endpoint edge cases."""
def test_raises_on_empty_queue(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue() # empty
with pytest.raises(ValueError, match="failed to get endpoint URL"):
transport._wait_for_endpoint(status_queue)
def test_raises_status_error_exception(self):
from core.mcp.client.sse_client import SSETransport, _StatusError
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
exc = ValueError("malicious endpoint")
status_queue.put(_StatusError(exc))
with pytest.raises(ValueError, match="malicious endpoint"):
transport._wait_for_endpoint(status_queue)
def test_raises_on_unknown_status_type(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
# Put an object that is neither _StatusReady nor _StatusError
status_queue.put("unexpected_value")
with pytest.raises(ValueError, match="failed to get endpoint URL"):
transport._wait_for_endpoint(status_queue)
class TestSSEClientRuntimeError:
"""Test sse_client context manager handles RuntimeError on close()."""
def test_runtime_error_on_close_is_suppressed(self):
"""Ensure RuntimeError raised by event_source.response.close() is caught."""
test_url = "http://test.example/sse"
class MockSSEEvent:
def __init__(self, event_type: str, data: str):
self.event = event_type
self.data = data
endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc:
mock_client = Mock()
mock_cf.return_value.__enter__.return_value = mock_client
mock_es = Mock()
mock_es.response.raise_for_status.return_value = None
mock_es.iter_sse.return_value = [endpoint_event]
# Make close() raise RuntimeError to exercise line 307-308
mock_es.response.close.side_effect = RuntimeError("already closed")
mock_sc.return_value.__enter__.return_value = mock_es
# Should NOT raise even though close() raises RuntimeError
with contextlib.suppress(Exception):
with sse_client(test_url) as (rq, wq):
pass
class TestStandaloneSendMessage:
"""Tests for the module-level send_message() function."""
def _make_session_message(self):
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
msg = types.JSONRPCMessage.model_validate_json(msg_json)
return types.SessionMessage(msg)
def test_send_message_success(self):
from core.mcp.client.sse_client import send_message
mock_response = Mock()
mock_response.status_code = 200
mock_http_client = Mock()
mock_http_client.post.return_value = mock_response
session_msg = self._make_session_message()
send_message(mock_http_client, "http://example.com/messages/", session_msg)
mock_http_client.post.assert_called_once()
mock_response.raise_for_status.assert_called_once()
def test_send_message_raises_on_http_error(self):
from core.mcp.client.sse_client import send_message
mock_http_client = Mock()
mock_http_client.post.side_effect = httpx.ConnectError("refused")
session_msg = self._make_session_message()
with pytest.raises(httpx.ConnectError):
send_message(mock_http_client, "http://example.com/messages/", session_msg)
def test_send_message_raises_for_status_failure(self):
from core.mcp.client.sse_client import send_message
mock_response = Mock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Not Found", request=Mock(), response=Mock(status_code=404)
)
mock_http_client = Mock()
mock_http_client.post.return_value = mock_response
session_msg = self._make_session_message()
with pytest.raises(httpx.HTTPStatusError):
send_message(mock_http_client, "http://example.com/messages/", session_msg)
class TestReadMessages:
"""Tests for the module-level read_messages() generator."""
def _make_mock_sse_event(self, event_type: str, data: str):
ev = Mock()
ev.event = event_type
ev.data = data
return ev
def test_valid_message_event_yields_session_message(self):
from core.mcp.client.sse_client import read_messages
valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
mock_sse_event = self._make_mock_sse_event("message", valid_json)
mock_client = Mock()
mock_client.events.return_value = [mock_sse_event]
results = list(read_messages(mock_client))
assert len(results) == 1
assert hasattr(results[0], "message")
def test_invalid_json_yields_exception(self):
from core.mcp.client.sse_client import read_messages
mock_sse_event = self._make_mock_sse_event("message", "{not valid json}")
mock_client = Mock()
mock_client.events.return_value = [mock_sse_event]
results = list(read_messages(mock_client))
assert len(results) == 1
assert isinstance(results[0], Exception)
def test_non_message_event_is_skipped(self):
from core.mcp.client.sse_client import read_messages
mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/")
mock_client = Mock()
mock_client.events.return_value = [mock_sse_event]
results = list(read_messages(mock_client))
# Non-message events produce no output
assert results == []
def test_outer_exception_yields_exc(self):
from core.mcp.client.sse_client import read_messages
boom = RuntimeError("stream broken")
mock_client = Mock()
mock_client.events.side_effect = boom
results = list(read_messages(mock_client))
assert len(results) == 1
assert results[0] is boom
def test_multiple_events_mixed(self):
from core.mcp.client.sse_client import read_messages
valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}'
events = [
self._make_mock_sse_event("endpoint", "/messages/"),
self._make_mock_sse_event("message", valid_json),
self._make_mock_sse_event("message", "{bad json}"),
]
mock_client = Mock()
mock_client.events.return_value = events
results = list(read_messages(mock_client))
# endpoint is skipped; 1 valid SessionMessage + 1 Exception
assert len(results) == 2
assert hasattr(results[0], "message")
assert isinstance(results[1], Exception)
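The assertions above pin down the generator's contract: "message" events yield parsed payloads, invalid JSON yields the exception object in-band, and all other event types are skipped. A self-contained sketch of that contract (a simplified stand-in for the real `read_messages`, using plain dicts instead of SSE event objects):

```python
import json

def read_messages_sketch(events):
    # "message" events yield parsed payloads; bad JSON yields the
    # exception object itself; every other event type is skipped.
    for ev in events:
        if ev["event"] != "message":
            continue
        try:
            yield json.loads(ev["data"])
        except json.JSONDecodeError as exc:
            yield exc

results = list(read_messages_sketch([
    {"event": "endpoint", "data": "/messages/"},
    {"event": "message", "data": '{"jsonrpc": "2.0", "id": 2, "result": {}}'},
    {"event": "message", "data": "{bad json}"},
]))
```

Yielding exceptions instead of raising keeps the stream alive after a single malformed event, which is what `test_multiple_events_mixed` verifies.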



@@ -1,617 +0,0 @@
import queue
import time
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import timedelta
from typing import Union
from unittest.mock import MagicMock, patch
import pytest
from httpx import HTTPStatusError, Request, Response
from pydantic import BaseModel, ConfigDict, RootModel
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.session.base_session import BaseSession, RequestResponder
from core.mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCResponse,
Notification,
RequestParams,
SessionMessage,
)
from core.mcp.types import (
Request as MCPRequest,
)
class MockRequestParams(RequestParams):
name: str = "default"
model_config = ConfigDict(extra="allow")
class MockRequest(MCPRequest[MockRequestParams, str]):
method: str = "test/request"
params: MockRequestParams = MockRequestParams()
class MockResult(BaseModel):
result: str
class MockNotificationParams(BaseModel):
message: str
class MockNotification(Notification[MockNotificationParams, str]):
method: str = "test/notification"
params: MockNotificationParams
class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]):
pass
class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]):
pass
class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.received_requests = []
self.received_notifications = []
self.handled_incoming = []
def _received_request(self, responder):
self.received_requests.append(responder)
def _received_notification(self, notification):
self.received_notifications.append(notification)
def _handle_incoming(self, item):
self.handled_incoming.append(item)
@pytest.fixture
def streams():
return queue.Queue(), queue.Queue()
@pytest.mark.timeout(5)
def test_request_responder_respond(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
on_complete = MagicMock()
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
responder = RequestResponder(
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
)
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
responder.respond(MockResult(result="ok"))
with responder as r:
r.respond(MockResult(result="ok"))
with pytest.raises(AssertionError, match="Request already responded to"):
r.respond(MockResult(result="error"))
assert responder.completed is True
on_complete.assert_called_once_with(responder)
msg = write_stream.get_nowait()
assert isinstance(msg.message.root, JSONRPCResponse)
assert msg.message.root.result == {"result": "ok"}
@pytest.mark.timeout(5)
def test_request_responder_cancel(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
on_complete = MagicMock()
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
responder = RequestResponder(
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
)
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
responder.cancel()
with responder as r:
r.cancel()
assert responder.completed is True
on_complete.assert_called_once_with(responder)
msg = write_stream.get_nowait()
assert isinstance(msg.message.root, JSONRPCError)
assert msg.message.root.error.message == "Request cancelled"
@pytest.mark.timeout(10)
def test_base_session_lifecycle(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session as s:
assert isinstance(s, MockSession)
assert s._executor is not None
assert s._receiver_future is not None
session._receiver_future.result(timeout=5.0)
assert session._receiver_future.done()
@pytest.mark.timeout(5)
def test_send_request_success(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_response():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"})
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
except Exception:
pass
import threading
t = threading.Thread(target=mock_response, daemon=True)
t.start()
with session:
result = session.send_request(request, MockResult)
assert result.result == "hello world"
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_retry_loop_coverage(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_delayed_response():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
time.sleep(0.2)
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"})
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
except Exception:
pass
import threading
t = threading.Thread(target=mock_delayed_response, daemon=True)
t.start()
with session:
result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1))
assert result.result == "slow"
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_jsonrpc_error(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error"))
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
except Exception:
pass
import threading
t = threading.Thread(target=mock_error, daemon=True)
t.start()
with session:
with pytest.raises(MCPConnectionError) as exc:
session.send_request(request, MockResult)
assert exc.value.args[0].message == "Error"
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_auth_error(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized"))
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
except Exception:
pass
import threading
t = threading.Thread(target=mock_error, daemon=True)
t.start()
with session:
with pytest.raises(MCPAuthError):
session.send_request(request, MockResult)
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_http_status_error_coverage(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_direct_http_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
# To cover line 263 in base_session.py, we MUST put non-401 HTTPStatusError
# DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError.
response = Response(status_code=403, request=Request("GET", "http://test"))
error = HTTPStatusError("Forbidden", request=response.request, response=response)
session._response_streams[req_id].put(error)
except Exception:
pass
import threading
t = threading.Thread(target=mock_direct_http_error, daemon=True)
t.start()
# We still need the session for request ID generation and queue setup
with session:
with pytest.raises(MCPConnectionError) as exc:
session.send_request(request, MockResult)
assert exc.value.args[0].code == 403
t.join(timeout=1)


@pytest.mark.timeout(5)
def test_send_request_http_status_auth_error(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    request = MockRequest(method="test", params=MockRequestParams(name="world"))

    def mock_error():
        try:
            msg = write_stream.get(timeout=2)
            req_id = msg.message.root.id
            response = Response(status_code=401, request=Request("GET", "http://test"))
            error = HTTPStatusError("Unauthorized", request=response.request, response=response)
            read_stream.put(error)
        except Exception:
            pass

    import threading

    t = threading.Thread(target=mock_error, daemon=True)
    t.start()
    with session:
        with pytest.raises(MCPAuthError):
            session.send_request(request, MockResult)
    t.join(timeout=1)


@pytest.mark.timeout(5)
def test_send_notification(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    notification = MockNotification(method="notify", params=MockNotificationParams(message="hi"))
    session.send_notification(notification, related_request_id="rel-1")
    msg = write_stream.get_nowait()
    assert isinstance(msg.message.root, JSONRPCNotification)
    assert msg.message.root.method == "notify"
    assert msg.message.root.params == {"message": "hi"}
    assert msg.metadata.related_request_id == "rel-1"


@pytest.mark.timeout(10)
def test_receive_loop_request(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}}
        read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
        for _ in range(30):
            if session.received_requests:
                break
            time.sleep(0.1)
        assert len(session.received_requests) == 1
        responder = session.received_requests[0]
        assert responder.request_id == 1
        assert responder.request.root.method == "test/request"


@pytest.mark.timeout(10)
def test_receive_loop_notification(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}}
        read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
        for _ in range(30):
            if session.received_notifications:
                break
            time.sleep(0.1)
        assert len(session.received_notifications) == 1
        assert isinstance(session.received_notifications[0].root, MockNotification)
        assert session.received_notifications[0].root.method == "test/notification"


@pytest.mark.timeout(15)
def test_receive_loop_cancel_notification(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification)
    with session:
        req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}}
        read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
        for _ in range(30):
            if "req-1" in session._in_flight:
                break
            time.sleep(0.1)
        assert "req-1" in session._in_flight
        responder = session._in_flight["req-1"]
        with responder:
            cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}}
            read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload)))
            for _ in range(30):
                if responder.completed:
                    break
                time.sleep(0.1)
        assert responder.completed is True
        msg = write_stream.get(timeout=2)
        assert isinstance(msg.message.root, JSONRPCError)
        assert msg.message.root.id == "req-1"


@pytest.mark.timeout(10)
def test_receive_loop_exception(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        read_stream.put(Exception("Unexpected error"))
        for _ in range(30):
            if any(isinstance(x, Exception) for x in session.handled_incoming):
                break
            time.sleep(0.1)
        assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming)


@pytest.mark.timeout(10)
def test_receive_loop_http_status_error(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        session._request_id = 1
        resp_queue = queue.Queue()
        session._response_streams[0] = resp_queue
        response = Response(status_code=401, request=Request("GET", "http://test"))
        # Using 401 specifically, as _receive_loop preserves it.
        error = HTTPStatusError("Unauthorized", request=response.request, response=response)
        read_stream.put(error)
        got = resp_queue.get(timeout=2)
        assert isinstance(got, HTTPStatusError)


@pytest.mark.timeout(10)
def test_receive_loop_http_status_error_non_401(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        session._request_id = 1
        resp_queue = queue.Queue()
        session._response_streams[0] = resp_queue
        response = Response(status_code=500, request=Request("GET", "http://test"))
        error = HTTPStatusError("Server Error", request=response.request, response=response)
        read_stream.put(error)
        got = resp_queue.get(timeout=2)
        assert isinstance(got, JSONRPCError)
        assert got.error.code == 500


@pytest.mark.timeout(5)
def test_check_receiver_status_fail(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    executor = ThreadPoolExecutor(max_workers=1)

    def raise_err():
        raise RuntimeError("Receiver failed")

    future = executor.submit(raise_err)
    session._receiver_future = future
    try:
        future.result()
    except Exception:
        pass
    with pytest.raises(RuntimeError, match="Receiver failed"):
        session.check_receiver_status()
    executor.shutdown()


@pytest.mark.timeout(10)
def test_receive_loop_unknown_request_id(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True})
        read_stream.put(SessionMessage(message=JSONRPCMessage(resp)))
        for _ in range(30):
            if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming):
                break
            time.sleep(0.1)
        assert any("Server Error" in str(x) for x in session.handled_incoming)


@pytest.mark.timeout(10)
def test_receive_loop_http_error_unknown_id(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with session:
        response = Response(status_code=401, request=Request("GET", "http://test"))
        error = HTTPStatusError("Unauthorized", request=response.request, response=response)
        read_stream.put(error)
        for _ in range(30):
            if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming):
                break
            time.sleep(0.1)
        assert any("unknown request ID" in str(x) for x in session.handled_incoming)


@pytest.mark.timeout(10)
def test_receive_loop_validation_error_notification(streams):
    from core.mcp.session.base_session import logger

    with patch.object(logger, "warning") as mock_warning:
        read_stream, write_stream = streams
        session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification])
        with session:
            notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}}
            read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
            time.sleep(1.0)
        assert mock_warning.called


@pytest.mark.timeout(5)
def test_send_request_none_response(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    request = MockRequest(method="test", params=MockRequestParams(name="world"))

    def mock_none():
        try:
            msg = write_stream.get(timeout=2)
            req_id = msg.message.root.id
            session._response_streams[req_id].put(None)
        except Exception:
            pass

    import threading

    t = threading.Thread(target=mock_none, daemon=True)
    t.start()
    with session:
        with pytest.raises(MCPConnectionError) as exc:
            session.send_request(request, MockResult)
    assert exc.value.args[0].message == "No response received"
    t.join(timeout=1)


@pytest.mark.timeout(15)
def test_session_exit_timeout(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    mock_future = MagicMock(spec=Future)
    mock_future.result.side_effect = TimeoutError()
    mock_future.done.return_value = False
    session._receiver_future = mock_future
    session._executor = MagicMock(spec=ThreadPoolExecutor)
    session.__exit__(None, None, None)
    mock_future.cancel.assert_called_once()
    session._executor.shutdown.assert_called_once_with(wait=False)


@pytest.mark.timeout(10)
def test_receive_loop_fatal_exception(streams):
    read_stream, write_stream = streams
    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")):
        with patch("core.mcp.session.base_session.logger") as mock_logger:
            with pytest.raises(RuntimeError, match="Fatal loop error"):
                with session:
                    pass
            mock_logger.exception.assert_called_with("Error in message processing loop")


@pytest.mark.timeout(5)
def test_receive_loop_empty_coverage(streams):
    with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1):
        read_stream, write_stream = streams
        session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
        with session:
            time.sleep(0.3)


@pytest.mark.timeout(2)
def test_base_methods_noop(streams):
    read_stream, write_stream = streams
    session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
    session._received_request(MagicMock())
    session._received_notification(MagicMock())
    session.send_progress_notification("token", 0.5)
    session._handle_incoming(MagicMock())


@pytest.mark.timeout(5)
def test_send_request_session_timeout_retry_6(streams):
    read_stream, write_stream = streams
    session = MockSession(
        read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1)
    )
    request = MockRequest(method="test", params=MockRequestParams(name="world"))
    with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]):
        with pytest.raises(RuntimeError, match="timeout_broken"):
            session.send_request(request, MockResult)
