Mirror of https://github.com/langgenius/dify.git (synced 2026-03-14 11:47:05 +00:00)

Compare commits: dependabot ... move-knowl (39 commits)
| Author | SHA1 | Date |
|---|---|---|
| | bf284f1436 | |
| | a32a79289c | |
| | b348a20d2b | |
| | 573b4e41cb | |
| | 194c205ed3 | |
| | 7e1dc3c122 | |
| | 4203647c32 | |
| | 20e91990bf | |
| | f38e8cca52 | |
| | 00eda73ad1 | |
| | 8b40a89add | |
| | 97776eabff | |
| | fe561ef3d0 | |
| | 1104d35bbb | |
| | 724eaee77e | |
| | 4717168fe2 | |
| | 7fd3bd81ab | |
| | 0dcfac5b84 | |
| | b66097b5f3 | |
| | ceaa399351 | |
| | dc50e4c4f2 | |
| | 157208ab1e | |
| | 3dabdc8282 | |
| | ed5511ce28 | |
| | 68982f910e | |
| | c43307dae1 | |
| | b44b37518a | |
| | b170eabaf3 | |
| | e99628b76f | |
| | 60fe5e7f00 | |
| | 245f6b824d | |
| | 7d2054d4f4 | |
| | 07e19c0748 | |
| | 135b3a15a6 | |
| | 0045e387f5 | |
| | 44713a5c0f | |
| | d5724aebde | |
| | c59685748c | |
| | 36c1f4d506 | |
.github/actions/setup-web/action.yml (vendored, 34 changed lines)

@@ -1,33 +1,13 @@
name: Setup Web Environment
description: Setup pnpm, Node.js, and install web dependencies.

inputs:
node-version:
description: Node.js version to use
required: false
default: "22"
install-dependencies:
description: Whether to install web dependencies after setting up Node.js
required: false
default: "true"

runs:
using: composite
steps:
- name: Install pnpm
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
- name: Setup Vite+
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
with:
package_json_file: web/package.json
run_install: false

- name: Setup Node.js
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: ${{ inputs.node-version }}
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

- name: Install dependencies
if: ${{ inputs.install-dependencies == 'true' }}
shell: bash
run: pnpm --dir web install --frozen-lockfile
node-version-file: "./web/.nvmrc"
cache: true
run-install: |
- cwd: ./web
args: ['--frozen-lockfile']
.github/workflows/autofix.yml (vendored, 4 changed lines)

@@ -102,13 +102,11 @@ jobs:
- name: Setup web environment
if: steps.web-changes.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
with:
node-version: "24"

- name: ESLint autofix
if: steps.web-changes.outputs.any_changed == 'true'
run: |
cd web
pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true

- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
.github/workflows/main-ci.yml (vendored, 3 changed lines)

@@ -62,6 +62,9 @@ jobs:
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
with:
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}

style-check:
name: Style Check
.github/workflows/style.yml (vendored, 8 changed lines)

@@ -88,7 +88,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
pnpm run lint:ci
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true

@@ -102,17 +102,17 @@
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run lint:tss
run: vp run lint:tss

- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check
run: vp run type-check

- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run knip
run: vp run knip

superlinter:
name: SuperLinter
.github/workflows/translate-i18n-claude.yml (vendored, 2 changed lines)

@@ -50,8 +50,6 @@ jobs:

- name: Setup web environment
uses: ./.github/actions/setup-web
with:
install-dependencies: "false"

- name: Detect changed files and generate diff
id: detect_changes
.github/workflows/web-tests.yml (vendored, 24 changed lines)

@@ -2,6 +2,13 @@ name: Web Tests

on:
workflow_call:
inputs:
base_sha:
required: false
type: string
head_sha:
required: false
type: string

permissions:
contents: read

@@ -14,6 +21,8 @@ jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
strategy:
fail-fast: false
matrix:

@@ -34,7 +43,7 @@
uses: ./.github/actions/setup-web

- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage

- name: Upload blob report
if: ${{ !cancelled() }}

@@ -50,6 +59,8 @@
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
defaults:
run:
shell: bash

@@ -59,6 +70,7 @@
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false

- name: Setup web environment

@@ -72,7 +84,13 @@
merge-multiple: true

- name: Merge reports
run: pnpm vitest --merge-reports --coverage --silent=passed-only
run: vp test --merge-reports --reporter=json --reporter=agent --coverage

- name: Check app/components diff coverage
env:
BASE_SHA: ${{ inputs.base_sha }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/check-components-diff-coverage.mjs

- name: Coverage Summary
if: always()

@@ -429,4 +447,4 @@
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run build
run: vp run build
@@ -188,7 +188,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word
@@ -102,7 +102,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.agent.agent_node -> core.agent.entities
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model

@@ -124,7 +123,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.file_saver -> core.tools.signature
@@ -17,11 +17,6 @@ class WeaviateConfig(BaseSettings):
default=None,
)

WEAVIATE_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
default=True,
)

WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
@@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):

def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):

@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
@@ -29,7 +29,9 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.nodes import register_core_nodes
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter

@@ -72,6 +74,8 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task

logger = logging.getLogger(__name__)

register_core_nodes()


class WorkflowBasedAppRunner:
def __init__(

@@ -490,7 +494,9 @@ class WorkflowBasedAppRunner:
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
retriever_resources=[
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
],
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
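The hunk above now validates each raw retriever-resource mapping into a RetrievalSourceMetadata model at the point where the queue event is published. A minimal sketch of that conversion, with illustrative field names rather than the repository's actual schema:

```python
# Hedged sketch: Pydantic's model_validate turns plain mappings into typed objects.
# The field names below are assumptions for illustration only.
from pydantic import BaseModel


class RetrievalSourceMetadataSketch(BaseModel):
    dataset_id: str | None = None
    document_id: str | None = None
    content: str | None = None


raw_resources = [{"dataset_id": "d1", "document_id": "doc-1", "content": "chunk text"}]
typed = [RetrievalSourceMetadataSketch.model_validate(resource) for resource in raw_resources]
assert typed[0].dataset_id == "d1"
```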
@@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
@@ -193,7 +193,8 @@ class LLMGenerator:
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
error = str(e)
error_step = "generate rule config"

rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""

@@ -279,7 +280,8 @@ class LLMGenerator:

except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
error = str(e)
error_step = "handle unexpected exception"

rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
except ValueError:
raise
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")


def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
@@ -9,8 +9,8 @@ from flask import current_app
from sqlalchemy import delete, func, select

from core.db.session_factory import session_factory
from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment

from .index_processor_factory import IndexProcessorFactory
@@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.knowledge_retrieval import exc
from dify_graph.repositories.rag_retrieval_protocol import (
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.nodes.knowledge_retrieval.retrieval import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.get_credentials_schema_by_type(CredentialType.API_KEY)

def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider

:param credential_type: the type of the credential
:return: the credentials schema of the provider
:param credential_type: the type of the credential, as CredentialType or str; str values
are normalized via CredentialType.of and may raise ValueError for invalid values.
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
empty list for CredentialType.UNAUTHORIZED or missing schemas.

Reads from self.entity.oauth_schema and self.entity.credentials_schema.
Raises ValueError for invalid credential types.
"""
if credential_type == CredentialType.OAUTH2.value:
if isinstance(credential_type, str):
credential_type = CredentialType.of(credential_type)
if credential_type == CredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
if credential_type == CredentialType.UNAUTHORIZED:
return []
raise ValueError(f"Invalid credential type: {credential_type}")

def get_oauth_client_schema(self) -> list[ProviderConfig]:
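The new signature above accepts either a CredentialType member or its string value and normalizes the string through CredentialType.of before comparing. A self-contained sketch of that normalization pattern (the enum values and the body of `of` are illustrative assumptions, not the repository's definitions):

```python
# Hedged sketch of StrEnum normalization; member values are placeholders.
from enum import StrEnum


class CredentialType(StrEnum):
    API_KEY = "api-key"
    OAUTH2 = "oauth2"
    UNAUTHORIZED = "unauthorized"

    @classmethod
    def of(cls, value: str) -> "CredentialType":
        try:
            return cls(value)
        except ValueError as e:
            raise ValueError(f"Invalid credential type: {value}") from e


def normalize(credential_type: "CredentialType | str") -> CredentialType:
    # Accept the enum member or its string form, as the new method signature does.
    if isinstance(credential_type, str):
        return CredentialType.of(credential_type)
    return credential_type


assert normalize("oauth2") is CredentialType.OAUTH2
```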
@@ -137,6 +137,7 @@ class ToolFileManager:

session.add(tool_file)
session.commit()
session.refresh(tool_file)

return tool_file
@@ -1,4 +1 @@
from .node_factory import DifyNodeFactory
from .workflow_entry import WorkflowEntry

__all__ = ["DifyNodeFactory", "WorkflowEntry"]
"""Core workflow package."""
@@ -8,7 +8,6 @@ from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,

@@ -17,11 +16,9 @@ from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.nodes import register_core_nodes
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY

@@ -53,6 +50,8 @@ if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState

register_core_nodes()


LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData

@@ -127,7 +126,6 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._http_request_file_manager = file_manager
self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",

@@ -187,21 +185,11 @@ class DifyNodeFactory(NodeFactory):
NodeType.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
},
NodeType.KNOWLEDGE_INDEX: lambda: {
"index_processor": IndexProcessor(),
"summary_index_service": SummaryIndex(),
},
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
NodeType.DATASOURCE: lambda: {
"datasource_manager": DatasourceManager,
},
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
"rag_retrieval": self._rag_retrieval,
},
NodeType.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
api/core/workflow/nodes/__init__.py (new file, 14 lines)

@@ -0,0 +1,14 @@
"""Workflow node implementations that remain under the legacy core.workflow namespace."""

from __future__ import annotations

import importlib
import pkgutil
from functools import lru_cache


@lru_cache(maxsize=1)
def register_core_nodes() -> None:
    """Import all core workflow node modules so they self-register with ``Node``."""
    for _, module_name, _ in pkgutil.walk_packages(__path__, __name__ + "."):
        importlib.import_module(module_name)
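The new helper registers nodes purely by importing their modules: each Node subclass adds itself to a registry from __init_subclass__, and lru_cache(maxsize=1) makes repeated calls from different entry points a no-op. A standalone sketch of that mechanism (class and registry names here are illustrative, not the dify_graph internals):

```python
# Hedged sketch of import-time registration via __init_subclass__.
from functools import lru_cache


class Node:
    registry: dict[str, type["Node"]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Defining (or importing a module that defines) a subclass records it here.
        Node.registry[cls.__name__] = cls


class AnswerNodeSketch(Node):
    pass


@lru_cache(maxsize=1)
def register_core_nodes() -> None:
    # The real module walks pkgutil.walk_packages(__path__) and imports each
    # submodule so its subclasses run the hook above; caching keeps it one-shot.
    pass


register_core_nodes()
register_core_nodes()  # cached, so the second call does nothing
assert "AnswerNodeSketch" in Node.registry
```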
api/core/workflow/nodes/datasource/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
"""Datasource workflow node package."""
@@ -1,22 +1,17 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.repositories.datasource_manager_protocol import (
DatasourceManagerProtocol,
DatasourceParameter,
OnlineDriveDownloadFileParam,
)

from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam
from .exc import DatasourceNodeError

if TYPE_CHECKING:

@@ -38,7 +33,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
):
super().__init__(
id=id,

@@ -46,7 +40,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.datasource_manager = datasource_manager
self.datasource_manager = DatasourceManager

def _run(self) -> Generator:
"""

@@ -42,3 +42,14 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
return typ

datasource_parameters: dict[str, DatasourceInput] | None = None


class DatasourceParameter(BaseModel):
workspace_id: str
page_id: str
type: str


class OnlineDriveDownloadFileParam(BaseModel):
id: str
bucket: str
@@ -1,25 +1,10 @@
from collections.abc import Generator
from typing import Any, Protocol

from pydantic import BaseModel

from dify_graph.file import File
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent


class DatasourceParameter(BaseModel):
workspace_id: str
page_id: str
type: str


class OnlineDriveDownloadFileParam(BaseModel):
id: str
bucket: str


class DatasourceFinal(BaseModel):
data: dict[str, Any] | None = None
from .entities import DatasourceParameter, OnlineDriveDownloadFileParam


class DatasourceManagerProtocol(Protocol):
api/core/workflow/nodes/knowledge_index/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
"""Knowledge index workflow node package."""
@@ -2,14 +2,14 @@ import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.summary_index.summary_index import SummaryIndex
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol
from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol

from .entities import KnowledgeIndexNodeData
from .exc import (

@@ -34,12 +34,10 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
index_processor: IndexProcessorProtocol,
summary_index_service: SummaryIndexServiceProtocol,
) -> None:
super().__init__(id, config, graph_init_params, graph_runtime_state)
self.index_processor = index_processor
self.summary_index_service = summary_index_service
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()

def _run(self) -> NodeRunResult: # type: ignore
node_data = self.node_data
@@ -5,21 +5,21 @@ from pydantic import BaseModel, Field


class PreviewItem(BaseModel):
content: str | None = Field(None)
child_chunks: list[str] | None = Field(None)
summary: str | None = Field(None)
content: str | None = Field(default=None)
child_chunks: list[str] | None = Field(default=None)
summary: str | None = Field(default=None)


class QaPreview(BaseModel):
answer: str | None = Field(None)
question: str | None = Field(None)
answer: str | None = Field(default=None)
question: str | None = Field(default=None)


class Preview(BaseModel):
chunk_structure: str
parent_mode: str | None = Field(None)
preview: list[PreviewItem] = Field([])
qa_preview: list[QaPreview] = Field([])
parent_mode: str | None = Field(default=None)
preview: list[PreviewItem] = Field(default_factory=list)
qa_preview: list[QaPreview] = Field(default_factory=list)
total_segments: int


@@ -39,3 +39,9 @@ class IndexProcessorProtocol(Protocol):
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
) -> Preview: ...


class SummaryIndexServiceProtocol(Protocol):
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
) -> None: ...
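The hunk above moves the Pydantic defaults from positional arguments to explicit keywords and swaps the mutable list defaults for default_factory. A small hedged sketch of the idiom, with field names simplified from the models above:

```python
# Hedged sketch: default_factory builds a fresh list for every instance.
from pydantic import BaseModel, Field


class PreviewSketch(BaseModel):
    parent_mode: str | None = Field(default=None)
    preview: list[str] = Field(default_factory=list)


a, b = PreviewSketch(), PreviewSketch()
a.preview.append("chunk")
assert b.preview == []  # b keeps its own, untouched list
```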
api/core/workflow/nodes/knowledge_retrieval/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
"""Knowledge retrieval workflow node package."""
@@ -1,8 +1,15 @@
"""Knowledge retrieval workflow node implementation.

This node now lives under ``core.workflow.nodes`` and is discovered directly by
the workflow node registry.
"""

import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (

@@ -15,7 +22,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from dify_graph.variables import (
ArrayFileSegment,
FileSegment,

@@ -32,6 +38,7 @@ from .exc import (
KnowledgeRetrievalNodeError,
RateLimitExceededError,
)
from .retrieval import KnowledgeRetrievalRequest, Source

if TYPE_CHECKING:
from dify_graph.file.models import File

@@ -53,7 +60,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
):
super().__init__(
id=id,

@@ -63,7 +69,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._rag_retrieval = rag_retrieval
self._rag_retrieval = DatasetRetrieval()

@classmethod
def version(cls):
@@ -3,9 +3,10 @@ from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field

from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from dify_graph.nodes.llm.entities import ModelConfig

from .entities import MetadataFilteringCondition


class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")

@@ -28,7 +29,7 @@ class SourceMetadata(BaseModel):
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
child_chunks: list[SourceChildChunk] = Field(default_factory=list, description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")

@@ -81,28 +82,7 @@ class KnowledgeRetrievalRequest(BaseModel):


class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.

Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""

@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...
def llm_usage(self) -> LLMUsage: ...

def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.

Args:
request: Knowledge retrieval request with search parameters

Returns:
List of sources matching the search criteria

Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: ...
@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.nodes import register_core_nodes
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter

@@ -33,6 +34,8 @@ from models.workflow import Workflow

logger = logging.getLogger(__name__)

register_core_nodes()


class _WorkflowChildEngineBuilder:
@staticmethod
@@ -9,10 +9,8 @@ from __future__ import annotations

from dataclasses import dataclass

from dify_graph.nodes.answer.answer_node import AnswerNode
from dify_graph.nodes import NodeType
from dify_graph.nodes.base.template import Template
from dify_graph.nodes.end.end_node import EndNode
from dify_graph.nodes.knowledge_index import KnowledgeIndexNode
from dify_graph.runtime.graph_runtime_state import NodeProtocol

@@ -33,10 +31,8 @@ class ResponseSession:
"""
Create a ResponseSession from a response-capable node.

The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a response node that exposes `id` and `get_streaming_template()`.

Args:
node: Node from the materialized workflow graph.

@@ -47,8 +43,10 @@ class ResponseSession:
Raises:
TypeError: If node is not a supported response node type.
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
if getattr(node, "node_type", None) not in {NodeType.ANSWER, NodeType.END, NodeType.KNOWLEDGE_INDEX}:
raise TypeError("ResponseSession.from_node only supports answer, end, or knowledge-index nodes")
if not hasattr(node, "get_streaming_template"):
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes")
return cls(
node_id=node.id,
template=node.get_streaming_template(),
@@ -1,9 +1,9 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any

from pydantic import Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason

@@ -31,7 +31,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):


class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
@@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):

:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_call_id:
return False

return True
return super().is_empty() and not self.tool_call_id
@@ -4,7 +4,8 @@ class InvokeError(ValueError):
description: str | None = None

def __init__(self, description: str | None = None):
self.description = description
if description is not None:
self.description = description

def __str__(self):
return self.description or self.__class__.__name__
@@ -282,7 +282,8 @@ class ModelProviderFactory:
all_model_type_models.append(model_schema)

simple_provider_schema = provider_schema.to_simple_provider()
simple_provider_schema.models.extend(all_model_type_models)
if model_type:
simple_provider_schema.models = all_model_type_models

providers.append(simple_provider_schema)
@@ -1,9 +1,9 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any

from pydantic import Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.file import File
from dify_graph.model_runtime.entities.llm_entities import LLMUsage

@@ -13,7 +13,7 @@ from .base import NodeEventBase


class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")
@@ -179,7 +179,8 @@ class Node(Generic[NodeDataT]):
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under dify_graph.nodes.*
# Only register production node implementations defined under the
# canonical workflow namespaces.
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")

@@ -187,7 +188,7 @@ class Node(Generic[NodeDataT]):
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith("dify_graph.nodes."):
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:

@@ -203,6 +204,7 @@ class Node(Generic[NodeDataT]):
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]
Node._registry_version += 1

@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:

@@ -237,6 +239,11 @@ class Node(Generic[NodeDataT]):

# Global registry populated via __init_subclass__
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
_registry_version: ClassVar[int] = 0

@classmethod
def get_registry_version(cls) -> int:
return cls._registry_version

def __init__(
self,

@@ -364,20 +371,18 @@ class Node(Generic[NodeDataT]):
)

# === FIXME(-LAN-): Needs to refactor.
from dify_graph.nodes.tool.tool_node import ToolNode

if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")

from dify_graph.nodes.datasource.datasource_node import DatasourceNode

if isinstance(self, DatasourceNode):
provider_id = getattr(self.node_data, "provider_id", "")
provider_type = getattr(self.node_data, "provider_type", "")
if not provider_id:
plugin_id = getattr(self.node_data, "plugin_id", "")
provider_name = getattr(self.node_data, "provider_name", "")
if plugin_id and provider_name:
provider_id = f"{plugin_id}/{provider_name}"

start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.node_data, "provider_type", "")
if provider_id:
start_event.provider_id = provider_id
if provider_type:
start_event.provider_type = str(provider_type)

from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode

@@ -523,17 +528,19 @@ class Node(Generic[NodeDataT]):
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.

Import all modules under dify_graph.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
Import all dify_graph node modules so subclasses register themselves on
import. Core workflow nodes are registered by the core workflow layer
before it reads the mapping, which keeps the dependency direction
pointing from core to dify_graph.
"""
# Import all node modules to ensure they are loaded (thus registered)
import dify_graph.nodes as _nodes_pkg
# Import all node modules to ensure they are loaded (thus registered).
import dify_graph.nodes as _dify_nodes_pkg

for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
for _, module_name, _ in pkgutil.walk_packages(_dify_nodes_pkg.__path__, _dify_nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports.
if _modname == "dify_graph.nodes.node_mapping":
if module_name == "dify_graph.nodes.node_mapping":
continue
importlib.import_module(_modname)
importlib.import_module(module_name)

# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
@@ -1,3 +0,0 @@
from .datasource_node import DatasourceNode

__all__ = ["DatasourceNode"]
@@ -72,8 +72,8 @@ class EmailDeliveryConfig(BaseModel):
body: str
debug_mode: bool = False

def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig":
if user_id is None:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])

@@ -141,7 +141,7 @@ def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
user_id: str | None,
) -> DeliveryChannelConfig:
if not enabled:
return method

@@ -149,7 +149,7 @@ def apply_debug_email_recipient(
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
debug_config = method.config.with_debug_recipient(user_id)
return method.model_copy(update={"config": debug_config})
@@ -1,3 +0,0 @@
from .knowledge_index_node import KnowledgeIndexNode

__all__ = ["KnowledgeIndexNode"]

@@ -1,3 +0,0 @@
from .knowledge_retrieval_node import KnowledgeRetrievalNode

__all__ = ["KnowledgeRetrievalNode"]
@@ -17,7 +17,6 @@ from core.llm_generator.output_parser.structured_output import invoke_llm_with_s
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams

@@ -677,7 +676,7 @@ class LLMNode(Node[LLMNodeData]):
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[RetrievalSourceMetadata] = []
original_retriever_resource: list[dict[str, Any]] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):

@@ -693,11 +692,14 @@ class LLMNode(Node[LLMNodeData]):
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
segment_id = retriever_resource.get("segment_id")
if not segment_id:
continue
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
SegmentAttachmentBinding.segment_id == segment_id,
)
).all()
if attachments_with_bindings:

@@ -723,7 +725,7 @@ class LLMNode(Node[LLMNodeData]):
context_files=context_files,
)

def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None:
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]

@@ -731,28 +733,26 @@ class LLMNode(Node[LLMNodeData]):
):
metadata = context_dict.get("metadata", {})

source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
summary=context_dict.get("summary"),
)

return source
return {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
"files": context_dict.get("files"),
"summary": context_dict.get("summary"),
}

return None
@@ -1,9 +1,53 @@
from collections.abc import Mapping
from collections.abc import Iterator, Mapping, MutableMapping

from dify_graph.enums import NodeType
from dify_graph.nodes.base.node import Node

LATEST_VERSION = "latest"

# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]):
"""Mutable dict-like view over the current node registry."""

def __init__(self) -> None:
self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {}
self._cached_version = -1
self._deleted: set[NodeType] = set()
self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {}

def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]:
current_version = Node.get_registry_version()
if self._cached_version != current_version:
self._cached_snapshot = dict(Node.get_node_type_classes_mapping())
self._cached_version = current_version
if not self._deleted and not self._overrides:
return self._cached_snapshot

snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted}
snapshot.update(self._overrides)
return snapshot

def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
return self._snapshot()[key]

def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
self._deleted.discard(key)
self._overrides[key] = value

def __delitem__(self, key: NodeType) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._cached_snapshot:
self._deleted.add(key)
return
raise KeyError(key)

def __iter__(self) -> Iterator[NodeType]:
return iter(self._snapshot())

def __len__(self) -> int:
return len(self._snapshot())


NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
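NODE_TYPE_CLASSES_MAPPING is now a live view: it resnapshots whenever Node.get_registry_version() changes and still lets callers (typically tests) override or delete individual node types. An illustrative usage sketch, assuming only names that appear in this diff and not verified against the repository:

```python
# Hedged usage sketch of the lazy node-type mapping.
from core.workflow.nodes import register_core_nodes
from dify_graph.enums import NodeType
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

register_core_nodes()  # bumps the registry version; the mapping resnapshots lazily

answer_versions = NODE_TYPE_CLASSES_MAPPING[NodeType.ANSWER]
answer_cls = answer_versions[LATEST_VERSION]

# Overrides shadow the snapshot; deleting the override restores the registry view.
NODE_TYPE_CLASSES_MAPPING[NodeType.ANSWER] = {LATEST_VERSION: answer_cls}
del NODE_TYPE_CLASSES_MAPPING[NodeType.ANSWER]
```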
@@ -1,7 +0,0 @@
from typing import Protocol


class SummaryIndexServiceProtocol(Protocol):
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
): ...
@@ -2,8 +2,8 @@ from typing import cast

from sqlalchemy import select

from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from dify_graph.nodes import NodeType
from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
@@ -123,7 +123,7 @@ dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
@@ -20,9 +20,9 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from dify_graph.enums import NodeType
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData

@@ -187,7 +187,10 @@ class AppService:
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
typed_tool = {key: value for key, value in tool.items() if isinstance(key, str)}
if len(typed_tool) != len(tool):
continue
agent_tool_entity = AgentToolEntity.model_validate(typed_tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@@ -58,8 +58,9 @@ class FileService:
# get file extension
extension = os.path.splitext(filename)[1].lstrip(".").lower()

# check if filename contains invalid characters
if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]):
# Only reject path separators here. The original filename is stored as metadata,
# while the storage key is UUID-based.
if any(c in filename for c in ["/", "\\"]):
raise ValueError("Filename contains invalid characters")

if len(filename) > 200:
@@ -36,6 +36,7 @@ from core.rag.entities.event import (
|
||||
)
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes import register_core_nodes
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
@@ -86,6 +87,8 @@ from services.workflow_draft_variable_service import DraftVariableSaver, DraftVa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_core_nodes()
|
||||
|
||||
|
||||
class RagPipelineService:
|
||||
def __init__(self, session_maker: sessionmaker | None = None):
|
||||
|
||||
@@ -22,10 +22,10 @@ from sqlalchemy.orm import Session
|
||||
from core.helper import ssrf_proxy
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.nodes.datasource.entities import DatasourceNodeData
|
||||
from dify_graph.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
|
||||
@@ -358,21 +358,19 @@ class WorkflowRunRestore:
|
||||
self,
|
||||
model: type[DeclarativeBase] | Any,
|
||||
) -> tuple[set[str], set[str], set[str]]:
|
||||
columns = list(model.__table__.columns)
|
||||
table = model.__table__
|
||||
columns = list(table.columns)
|
||||
autoincrement_column = getattr(table, "autoincrement_column", None)
|
||||
|
||||
def has_insert_default(column: Any) -> bool:
|
||||
# SQLAlchemy may set column.autoincrement to "auto" on non-PK columns.
|
||||
# Only treat the resolved autoincrement column as DB-generated.
|
||||
return column.default is not None or column.server_default is not None or column is autoincrement_column
|
||||
|
||||
column_names = {column.key for column in columns}
|
||||
required_columns = {
|
||||
column.key
|
||||
for column in columns
|
||||
if not column.nullable
|
||||
and column.default is None
|
||||
and column.server_default is None
|
||||
and not column.autoincrement
|
||||
}
|
||||
required_columns = {column.key for column in columns if not column.nullable and not has_insert_default(column)}
|
||||
non_nullable_with_default = {
|
||||
column.key
|
||||
for column in columns
|
||||
if not column.nullable
|
||||
and (column.default is not None or column.server_default is not None or column.autoincrement)
|
||||
column.key for column in columns if not column.nullable and has_insert_default(column)
|
||||
}
|
||||
return column_names, required_columns, non_nullable_with_default
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.nodes import register_core_nodes
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
@@ -70,6 +71,8 @@ from .human_input_delivery_test_service import (
|
||||
)
|
||||
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
register_core_nodes()
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""
|
||||
@@ -952,7 +955,7 @@ class WorkflowService:
|
||||
delivery_method = apply_debug_email_recipient(
|
||||
delivery_method,
|
||||
enabled=True,
|
||||
user_id=account.id or "",
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
variable_pool = self._build_human_input_variable_pool(
|
||||
|
||||
@@ -60,7 +60,6 @@ VECTOR_STORE=weaviate
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
|
||||
class _Seg:
|
||||
@@ -28,13 +29,17 @@ class _GS:
|
||||
|
||||
|
||||
class _GP:
|
||||
tenant_id = "t1"
|
||||
app_id = "app-1"
|
||||
workflow_id = "wf-1"
|
||||
graph_config = {}
|
||||
user_id = "u1"
|
||||
user_from = "account"
|
||||
invoke_from = "debugger"
|
||||
run_context = {
|
||||
DIFY_RUN_CONTEXT_KEY: {
|
||||
"tenant_id": "t1",
|
||||
"app_id": "app-1",
|
||||
"user_id": "u1",
|
||||
"user_from": "account",
|
||||
"invoke_from": "debugger",
|
||||
}
|
||||
}
|
||||
call_depth = 0
|
||||
|
||||
|
||||
@@ -61,6 +66,8 @@ def test_node_integration_minimal_stream(mocker):
|
||||
def get_upload_file_by_id(cls, **_):
|
||||
raise AssertionError
|
||||
|
||||
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
@@ -77,7 +84,6 @@ def test_node_integration_minimal_stream(mocker):
|
||||
},
|
||||
graph_init_params=_GP(),
|
||||
graph_runtime_state=_GS(vp),
|
||||
datasource_manager=_Mgr,
|
||||
)
|
||||
|
||||
out = list(node._run())
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Integration tests for AppModelConfig using testcontainers.
|
||||
|
||||
These tests validate database-backed model behavior without mocking SQLAlchemy queries.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import AppModelConfig
|
||||
|
||||
|
||||
class TestAppModelConfig:
|
||||
"""Integration tests for AppModelConfig."""
|
||||
|
||||
def test_annotation_reply_dict_disabled_without_setting(self, db_session_with_containers: Session) -> None:
|
||||
"""Return disabled annotation reply dict when no AppAnnotationSetting exists."""
|
||||
# Arrange
|
||||
config = AppModelConfig(app_id=str(uuid4()))
|
||||
db_session_with_containers.add(config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act
|
||||
result = config.annotation_reply_dict
|
||||
|
||||
# Assert
|
||||
assert result == {"enabled": False}
|
||||
|
||||
# Cleanup
|
||||
db_session_with_containers.delete(config)
|
||||
db_session_with_containers.commit()
|
||||
@@ -263,6 +263,27 @@ class TestFileService:
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_upload_file_allows_regular_punctuation_in_filename(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload allows punctuation that is safe when stored as metadata.
|
||||
"""
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = 'candidate?resume for "dify"<final>|v2:.txt'
|
||||
content = b"test content"
|
||||
mimetype = "text/plain"
|
||||
|
||||
upload_file = FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file.name == filename
|
||||
|
||||
def test_upload_file_filename_too_long(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
313
api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py
Normal file
313
api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin endpoints
|
||||
|
||||
Tests endpoint structure (method existence) for all plugin APIs, plus
|
||||
handler-level logic tests for representative non-streaming endpoints.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api.plugin.plugin import (
|
||||
PluginFetchAppInfoApi,
|
||||
PluginInvokeAppApi,
|
||||
PluginInvokeEncryptApi,
|
||||
PluginInvokeLLMApi,
|
||||
PluginInvokeLLMWithStructuredOutputApi,
|
||||
PluginInvokeModerationApi,
|
||||
PluginInvokeParameterExtractorNodeApi,
|
||||
PluginInvokeQuestionClassifierNodeApi,
|
||||
PluginInvokeRerankApi,
|
||||
PluginInvokeSpeech2TextApi,
|
||||
PluginInvokeSummaryApi,
|
||||
PluginInvokeTextEmbeddingApi,
|
||||
PluginInvokeToolApi,
|
||||
PluginInvokeTTSApi,
|
||||
PluginUploadFileRequestApi,
|
||||
)
|
||||
|
||||
|
||||
def _extract_raw_post(cls):
|
||||
"""Extract the raw post() method from a plugin endpoint class.
|
||||
|
||||
Plugin endpoint methods are wrapped by several decorators (get_user_tenant,
|
||||
setup_required, plugin_inner_api_only, plugin_data). These decorators
|
||||
use @wraps where possible. This helper ensures we retrieve the original
|
||||
post(self, user_model, tenant_model, payload) function by unwrapping
|
||||
and, if necessary, walking the closure of the innermost wrapper.
|
||||
"""
|
||||
bottom = inspect.unwrap(cls.post)
|
||||
|
||||
# If unwrap() didn't get us to the raw function (e.g. if a decorator
|
||||
# missed @wraps), try to extract it from the closure if it looks like
|
||||
# a plugin_data or similar wrapper that closes over 'view_func'.
|
||||
if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars:
|
||||
try:
|
||||
idx = bottom.__code__.co_freevars.index("view_func")
|
||||
return bottom.__closure__[idx].cell_contents
|
||||
except (AttributeError, TypeError, IndexError):
|
||||
pass
|
||||
|
||||
return bottom
|
||||
|
||||
|
||||
class TestPluginInvokeLLMApi:
|
||||
"""Test PluginInvokeLLMApi endpoint structure"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeLLMWithStructuredOutputApi:
|
||||
"""Test PluginInvokeLLMWithStructuredOutputApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMWithStructuredOutputApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTextEmbeddingApi:
|
||||
"""Test PluginInvokeTextEmbeddingApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTextEmbeddingApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeRerankApi:
|
||||
"""Test PluginInvokeRerankApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeRerankApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTTSApi:
|
||||
"""Test PluginInvokeTTSApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTTSApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeSpeech2TextApi:
|
||||
"""Test PluginInvokeSpeech2TextApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSpeech2TextApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeModerationApi:
|
||||
"""Test PluginInvokeModerationApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeModerationApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeToolApi:
|
||||
"""Test PluginInvokeToolApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeToolApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeParameterExtractorNodeApi:
|
||||
"""Test PluginInvokeParameterExtractorNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeParameterExtractorNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeQuestionClassifierNodeApi:
|
||||
"""Test PluginInvokeQuestionClassifierNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeQuestionClassifierNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeAppApi:
|
||||
"""Test PluginInvokeAppApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeAppApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeEncryptApi:
|
||||
"""Test PluginInvokeEncryptApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeEncryptApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() delegates to PluginEncrypter and returns model_dump output"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act — extract raw post() bypassing all decorators including plugin_data
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload)
|
||||
assert result["data"] == {"encrypted": "data"}
|
||||
assert result.get("error") == ""
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() catches exceptions and returns error response"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed")
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
assert "encrypt failed" in result["error"]
|
||||
|
||||
|
||||
class TestPluginInvokeSummaryApi:
|
||||
"""Test PluginInvokeSummaryApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSummaryApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginUploadFileRequestApi:
|
||||
"""Test PluginUploadFileRequestApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginUploadFileRequestApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin")
|
||||
def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask):
|
||||
"""Test that post() generates a signed URL and returns it"""
|
||||
# Arrange
|
||||
mock_get_url.return_value = "https://storage.example.com/signed-upload-url"
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-id"
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.filename = "test.pdf"
|
||||
mock_payload.mimetype = "application/pdf"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginUploadFileRequestApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_get_url.assert_called_once_with(
|
||||
filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id"
|
||||
)
|
||||
assert result["data"]["url"] == "https://storage.example.com/signed-upload-url"
|
||||
|
||||
|
||||
class TestPluginFetchAppInfoApi:
|
||||
"""Test PluginFetchAppInfoApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginFetchAppInfoApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation")
|
||||
def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask):
|
||||
"""Test that post() fetches app info and returns it"""
|
||||
# Arrange
|
||||
mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.app_id = "app-123"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginFetchAppInfoApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id")
|
||||
assert result["data"] == {"app_name": "My App", "mode": "chat"}
|
||||
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.plugin.wraps import (
|
||||
TenantUserPayload,
|
||||
get_user,
|
||||
get_user_tenant,
|
||||
plugin_data,
|
||||
)
|
||||
|
||||
|
||||
class TestTenantUserPayload:
|
||||
"""Test TenantUserPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload passes validation"""
|
||||
data = {"tenant_id": "tenant123", "user_id": "user456"}
|
||||
payload = TenantUserPayload.model_validate(data)
|
||||
assert payload.tenant_id == "tenant123"
|
||||
assert payload.user_id == "user456"
|
||||
|
||||
def test_missing_tenant_id(self):
|
||||
"""Test missing tenant_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"user_id": "user456"})
|
||||
|
||||
def test_missing_user_id(self):
|
||||
"""Test missing user_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"tenant_id": "tenant123"})
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test get_user function"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test returning existing user when found by ID"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.query.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_anonymous_user_by_session_id(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test returning existing anonymous user by session_id"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "anonymous_session")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test creating new user when not found in database"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_new_user
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_use_default_session_id_when_user_id_none(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test using default session ID when user_id is None"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", None)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_raise_error_on_database_exception(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test raising ValueError when database operation fails"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
with pytest.raises(ValueError, match="user not found"):
|
||||
get_user("tenant123", "user123")
|
||||
|
||||
|
||||
class TestGetUserTenant:
|
||||
"""Test get_user_tenant decorator"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that decorator injects tenant_model and user_model into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user456"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
|
||||
def test_should_raise_error_when_tenant_id_missing(self, app: Flask):
|
||||
"""Test that Pydantic ValidationError is raised when tenant_id is missing from payload"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert - Pydantic validates payload before manual check
|
||||
with app.test_request_context(json={"user_id": "user456"}):
|
||||
with pytest.raises(ValidationError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_when_tenant_not_found(self, app: Flask):
|
||||
"""Test that ValueError is raised when tenant is not found"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="tenant not found"):
|
||||
protected_view()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that default session ID is used when user_id is empty string"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
|
||||
# Act - use empty string for user_id to trigger default logic
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
from models.model import DefaultEndUserSessionID
|
||||
|
||||
mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID)
|
||||
|
||||
|
||||
class PluginTestPayload:
|
||||
"""Simple test payload class"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.value = data.get("value")
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
return cls(data)
|
||||
|
||||
|
||||
class TestPluginData:
|
||||
"""Test plugin_data decorator"""
|
||||
|
||||
def test_should_inject_valid_payload(self, app: Flask):
|
||||
"""Test that valid payload is injected into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "test_data"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "test_data"
|
||||
|
||||
def test_should_raise_error_on_invalid_json(self, app: Flask):
|
||||
"""Test that ValueError is raised when JSON parsing fails"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert - Malformed JSON triggers ValueError
|
||||
with app.test_request_context(data="not valid json", content_type="application/json"):
|
||||
with pytest.raises(ValueError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_on_invalid_payload(self, app: Flask):
|
||||
"""Test that ValueError is raised when payload validation fails"""
|
||||
|
||||
# Arrange
|
||||
class InvalidPayload:
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
raise Exception("Validation failed")
|
||||
|
||||
@plugin_data(payload_type=InvalidPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"data": "test"}):
|
||||
with pytest.raises(ValueError, match="invalid payload"):
|
||||
protected_view()
|
||||
|
||||
def test_should_work_as_parameterized_decorator(self, app: Flask):
|
||||
"""Test that decorator works when used with parentheses"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "parameterized"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "parameterized"
|
||||
309
api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py
Normal file
309
api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Unit tests for inner_api auth decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.inner_api.wraps import (
|
||||
billing_inner_api_only,
|
||||
enterprise_inner_api_only,
|
||||
enterprise_inner_api_user_auth,
|
||||
plugin_inner_api_only,
|
||||
)
|
||||
|
||||
|
||||
class TestBillingInnerApiOnly:
|
||||
"""Test billing_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiOnly:
|
||||
"""Test enterprise_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiUserAuth:
|
||||
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
|
||||
|
||||
def test_should_pass_through_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that request passes through when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_header_missing(self, app: Flask):
|
||||
"""Test that request passes through when Authorization header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_format_invalid(self, app: Flask):
|
||||
"""Test that request passes through when Authorization format is invalid (no colon)"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"Authorization": "invalid_format"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask):
|
||||
"""Test that request passes through when HMAC signature is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act - use wrong signature
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_inject_user_when_hmac_signature_valid(self, app: Flask):
|
||||
"""Test that user is injected when HMAC signature is valid"""
|
||||
# Arrange
|
||||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user")
|
||||
|
||||
# Calculate valid HMAC signature
|
||||
user_id = "user123"
|
||||
inner_api_key = "valid_key"
|
||||
data_to_sign = f"DIFY {user_id}"
|
||||
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
|
||||
valid_signature = b64encode(signature.digest()).decode("utf-8")
|
||||
|
||||
# Create mock user
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
class TestPluginInnerApiOnly:
|
||||
"""Test plugin_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when PLUGIN_DAEMON_KEY is set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask):
|
||||
"""Test that 404 is returned when PLUGIN_DAEMON_KEY is not set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_404_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
206
api/tests/unit_tests/controllers/inner_api/test_mail.py
Normal file
206
api/tests/unit_tests/controllers/inner_api/test_mail.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Unit tests for inner_api mail module
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.mail import (
|
||||
BaseMail,
|
||||
BillingMail,
|
||||
EnterpriseMail,
|
||||
InnerMailPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestInnerMailPayload:
|
||||
"""Test InnerMailPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload_with_all_fields(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
"substitutions": {"key": "value"},
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions == {"key": "value"}
|
||||
|
||||
def test_valid_payload_without_substitutions(self):
|
||||
"""Test valid payload without optional substitutions"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions is None
|
||||
|
||||
def test_empty_to_list_fails_validation(self):
|
||||
"""Test that empty 'to' list fails validation due to min_length=1"""
|
||||
data = {
|
||||
"to": [],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_multiple_recipients_allowed(self):
|
||||
"""Test that multiple recipients are allowed"""
|
||||
data = {
|
||||
"to": ["user1@example.com", "user2@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert len(payload.to) == 2
|
||||
assert "user1@example.com" in payload.to
|
||||
assert "user2@example.com" in payload.to
|
||||
|
||||
def test_missing_to_field_fails_validation(self):
|
||||
"""Test that missing 'to' field fails validation"""
|
||||
data = {
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_subject_fails_validation(self):
|
||||
"""Test that missing 'subject' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_body_fails_validation(self):
|
||||
"""Test that missing 'body' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
|
||||
class TestBaseMail:
|
||||
"""Test BaseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BaseMail API instance"""
|
||||
return BaseMail()
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_sends_email_task(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends inner email task"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
json={
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
):
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Test Subject",
|
||||
body="Test Body",
|
||||
substitutions=None,
|
||||
)
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_with_substitutions(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends email with substitutions"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Hello {{name}}",
|
||||
"body": "Welcome {{name}}!",
|
||||
"substitutions": {"name": "John"},
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Hello {{name}}",
|
||||
body="Welcome {{name}}!",
|
||||
substitutions={"name": "John"},
|
||||
)
|
||||
|
||||
|
||||
class TestEnterpriseMail:
|
||||
"""Test EnterpriseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create EnterpriseMail API instance"""
|
||||
return EnterpriseMail()
|
||||
|
||||
def test_has_enterprise_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has enterprise_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
|
||||
assert enterprise_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
|
||||
|
||||
class TestBillingMail:
|
||||
"""Test BillingMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BillingMail API instance"""
|
||||
return BillingMail()
|
||||
|
||||
def test_has_billing_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that BillingMail has billing_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import billing_inner_api_only
|
||||
|
||||
assert billing_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that BillingMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Unit tests for inner_api workspace module
|
||||
|
||||
Tests Pydantic model validation and endpoint handler logic.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them and focus on business logic.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.workspace.workspace import (
|
||||
EnterpriseWorkspace,
|
||||
EnterpriseWorkspaceNoOwnerEmail,
|
||||
WorkspaceCreatePayload,
|
||||
WorkspaceOwnerlessPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkspaceCreatePayload:
|
||||
"""Test WorkspaceCreatePayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"name": "My Workspace",
|
||||
"owner_email": "owner@example.com",
|
||||
}
|
||||
payload = WorkspaceCreatePayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
assert payload.owner_email == "owner@example.com"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {"owner_email": "owner@example.com"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
def test_missing_owner_email_fails_validation(self):
|
||||
"""Test that missing owner_email fails validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "owner_email" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestWorkspaceOwnerlessPayload:
|
||||
"""Test WorkspaceOwnerlessPayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with name passes validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
payload = WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspace:
|
||||
"""Test EnterpriseWorkspace API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspace()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that EnterpriseWorkspace has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace and assigns the owner account"""
|
||||
# Arrange
|
||||
mock_account = MagicMock()
|
||||
mock_account.email = "owner@example.com"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["name"] == "My Workspace"
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
|
||||
"""Test that post() returns 404 when the owner account does not exist"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "owner account not found."}, 404)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspaceNoOwnerEmail:
|
||||
"""Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspaceNoOwnerEmail()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace without an owner and returns expected fields"""
|
||||
# Arrange
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.encrypt_public_key = "pub-key"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.custom_config = None
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["encrypt_public_key"] == "pub-key"
|
||||
assert result["tenant"]["custom_config"] == {}
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
80
api/tests/unit_tests/core/agent/conftest.py
Normal file
@@ -0,0 +1,80 @@
import pytest


class DummyTool:
    def __init__(self, name):
        self.name = name


class DummyPromptEntity:
    def __init__(self, first_prompt):
        self.first_prompt = first_prompt


class DummyAgentConfig:
    def __init__(self, prompt_entity=None):
        self.prompt = prompt_entity


class DummyAppConfig:
    def __init__(self, agent=None):
        self.agent = agent


class DummyScratchpadUnit:
    def __init__(
        self,
        final=False,
        thought=None,
        action_str=None,
        observation=None,
        agent_response=None,
    ):
        self._final = final
        self.thought = thought
        self.action_str = action_str
        self.observation = observation
        self.agent_response = agent_response

    def is_final(self):
        return self._final


@pytest.fixture
def dummy_tool_factory():
    def _factory(name):
        return DummyTool(name)

    return _factory


@pytest.fixture
def dummy_prompt_entity_factory():
    def _factory(first_prompt):
        return DummyPromptEntity(first_prompt)

    return _factory


@pytest.fixture
def dummy_agent_config_factory():
    def _factory(prompt_entity=None):
        return DummyAgentConfig(prompt_entity)

    return _factory


@pytest.fixture
def dummy_app_config_factory():
    def _factory(agent=None):
        return DummyAppConfig(agent)

    return _factory


@pytest.fixture
def dummy_scratchpad_unit_factory():
    def _factory(**kwargs):
        return DummyScratchpadUnit(**kwargs)

    return _factory
@@ -1,70 +1,255 @@
|
||||
"""Unit tests for CotAgentOutputParser.
|
||||
|
||||
Verifies expected parsing behavior for streaming content and JSON payloads,
|
||||
including edge cases such as empty/non-string content and malformed JSON.
|
||||
Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real
|
||||
model output structures. Implementation under test:
|
||||
core.agent.output_parser.cot_output_parser.CotAgentOutputParser.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
|
||||
|
||||
|
||||
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
|
||||
for i in range(len(text)):
|
||||
yield LLMResultChunk(
|
||||
model="model",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
)

@pytest.fixture
|
||||
def mock_action_class(mocker):
|
||||
mock_action = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
|
||||
mock_action,
|
||||
)
|
||||
return mock_action
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usage_dict():
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_chunk():
|
||||
def _make_chunk(content=None, usage=None):
|
||||
delta = SimpleNamespace(
|
||||
message=SimpleNamespace(content=content),
|
||||
usage=usage,
|
||||
)
|
||||
return SimpleNamespace(delta=delta)
|
||||
|
||||
return _make_chunk
|
||||
|
||||
|
||||
def test_cot_output_parser():
|
||||
test_cases = [
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with json
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with JSON
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# list
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block
|
||||
{
|
||||
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block and json
|
||||
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
|
||||
]
|
||||
# ============================================================
|
||||
# Test Suite
|
||||
# ============================================================
|
||||
|
||||
parser = CotAgentOutputParser()
|
||||
usage_dict = {}
|
||||
for test_case in test_cases:
|
||||
# mock llm_response as a generator by text
|
||||
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
|
||||
results = parser.handle_react_stream_output(llm_response, usage_dict)
|
||||
output = ""
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
output += result
|
||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||
if test_case["action"]:
|
||||
assert result.to_dict() == test_case["action"]
|
||||
output += json.dumps(result.to_dict())
|
||||
if test_case["output"]:
|
||||
assert output == test_case["output"]
|
||||
|
||||
class TestCotAgentOutputParser:
|
||||
"""Validate CotAgentOutputParser streaming + JSON parsing behavior.
|
||||
|
||||
Lifecycle: no explicit setup/teardown; relies on pytest fixtures for
|
||||
lightweight chunk/action doubles. Invariants: non-string/empty content
|
||||
yields no output, usage gets recorded when provided, and valid action JSON
|
||||
results in Action instantiation. Usage: invoke via pytest (e.g.,
|
||||
`pytest -k TestCotAgentOutputParser`).
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Basic streaming & usage
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_stream_plain_text(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("hello world")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "".join(result) == "hello world"
|
||||
|
||||
def test_stream_empty_string(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_stream_none_content(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk(None)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.parametrize("content", [123, 12.5, [], {}, object()])
|
||||
def test_non_string_content(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_usage_update(self, make_chunk, usage_dict) -> None:
|
||||
usage_data = {"tokens": 99}
|
||||
chunks = [make_chunk("abc", usage=usage_data)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert usage_dict["usage"] == usage_data
|
||||
|
||||
# --------------------------------------------------------
|
||||
# JSON parsing (direct + streaming)
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '{"action": "search", "input": "query"}'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="search", action_input="query")
|
||||
|
||||
def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '[{"action": "lookup", "input": "abc"}]'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None:
|
||||
content = '{"foo": "bar"}'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# Expect the serialized JSON to be yielded as a single element.
|
||||
assert result == [json.dumps({"foo": "bar"})]
|
||||
|
||||
def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None:
|
||||
content = "{invalid json}"
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert any("invalid json" in str(r) for r in result)
|
||||
|
||||
def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
chunks = [
|
||||
make_chunk('{"action": '),
|
||||
make_chunk('"multi", '),
|
||||
make_chunk('"input": "step"}'),
|
||||
]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="multi", action_input="step")
|
||||
|
||||
def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('{"foo": "bar"')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('{"foo": "bar"' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Code block JSON extraction
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = """```json
|
||||
{"action": "lookup", "input": "abc"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
# Multiple JSON objects inside single code fence (invalid combined JSON)
|
||||
# Parser should safely ignore invalid combined block
|
||||
content = """```json
|
||||
{"action": "a1", "input": "x"}
|
||||
{"action": "a2", "input": "y"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# No valid parsed action expected due to invalid combined JSON
|
||||
assert mock_action_class.call_count == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None:
|
||||
content = """```json
|
||||
{invalid}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result
|
||||
|
||||
def test_unclosed_code_block(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('```json {"a":1}')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('```json {"a":1}' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Action / Thought prefix handling
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
" action: something",
|
||||
" ACTION: something",
|
||||
" thought: reasoning",
|
||||
" THOUGHT: reasoning",
|
||||
],
|
||||
)
|
||||
def test_prefix_handling(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
joined = "".join(str(item) for item in result)
|
||||
expected_word = "something" if "action:" in content.lower() else "reasoning"
|
||||
assert expected_word in joined
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("xaction: test")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "x" in "".join(map(str, result))
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Mixed streaming scenarios
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = 'start {"action": "mix", "input": "1"} end'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# JSON action should be parsed
|
||||
mock_action_class.assert_called_once()
|
||||
# Ensure surrounding text is streamed (character-level)
|
||||
joined = "".join(str(r) for r in result if not isinstance(r, MagicMock))
|
||||
assert "start" in joined
|
||||
assert "end" in joined
|
||||
|
||||
def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert mock_action_class.call_count == 2
|
||||
|
||||
def test_backtick_noise(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("text with ` random ` backticks")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "text with" in "".join(result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Boundary & edge inputs
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
"```",
|
||||
"{",
|
||||
"}",
|
||||
"```json",
|
||||
"action:",
|
||||
"thought:",
|
||||
" ",
|
||||
],
|
||||
)
|
||||
def test_edge_inputs(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
joined = "".join(result)
|
||||
if content == " ":
|
||||
assert result == [] or joined == content
|
||||
if content in {"```", "{", "}", "```json"}:
|
||||
assert content in joined
|
||||
if content.lower() in {"action:", "thought:"}:
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
174
api/tests/unit_tests/core/agent/strategy/test_base.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
|
||||
|
||||
class DummyStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Concrete implementation for testing BaseAgentStrategy
|
||||
"""
|
||||
|
||||
def __init__(self, return_values=None, raise_exception=None):
|
||||
self.return_values = return_values or []
|
||||
self.raise_exception = raise_exception
|
||||
self.received_args = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id=None,
|
||||
app_id=None,
|
||||
message_id=None,
|
||||
credentials=None,
|
||||
) -> Generator:
|
||||
self.received_args = (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
credentials,
|
||||
)
|
||||
|
||||
if self.raise_exception:
|
||||
raise self.raise_exception
|
||||
|
||||
yield from self.return_values
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInstantiation:
|
||||
def test_cannot_instantiate_abstract_class(self) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
BaseAgentStrategy()
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInvoke:
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
return MagicMock(name="AgentInvokeMessage")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials(self):
|
||||
return MagicMock(name="InvokeCredentials")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("params", "user_id", "conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
({"key": "value"}, "user1", "conv1", "app1", "msg1"),
|
||||
({}, "user2", None, None, None),
|
||||
({"a": 1}, "", "", "", ""),
|
||||
({"nested": {"x": 1}}, "user3", None, "app3", None),
|
||||
],
|
||||
)
|
||||
def test_invoke_success(
|
||||
self,
|
||||
mock_message,
|
||||
mock_credentials,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[mock_message])
|
||||
|
||||
# Act
|
||||
result = list(
|
||||
strategy.invoke(
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
credentials=mock_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == [mock_message]
|
||||
assert strategy.received_args == (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
mock_credentials,
|
||||
)
|
||||
|
||||
def test_invoke_multiple_yields(self, mock_message) -> None:
|
||||
# Arrange
|
||||
messages = [mock_message, MagicMock(), MagicMock()]
|
||||
strategy = DummyStrategy(return_values=messages)
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == messages
|
||||
|
||||
def test_invoke_empty_generator(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_invoke_propagates_exception(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(raise_exception=ValueError("failure"))
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="failure"):
|
||||
list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_params",
|
||||
[
|
||||
None,
|
||||
"",
|
||||
123,
|
||||
[],
|
||||
],
|
||||
)
|
||||
def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None:
|
||||
"""
|
||||
Base class does not validate types — ensure pass-through behavior
|
||||
"""
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params=invalid_params, user_id="user"))
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_invoke_none_user_id(self) -> None:
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params={}, user_id=None))
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestBaseAgentStrategyGetParameters:
|
||||
def test_get_parameters_default_empty_list(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
result = strategy.get_parameters()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result == []
|
||||
|
||||
def test_get_parameters_returns_new_list_each_time(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
|
||||
first = strategy.get_parameters()
|
||||
second = strategy.get_parameters()
|
||||
|
||||
assert first == second == []
|
||||
assert first is not second
|
||||
272
api/tests/unit_tests/core/agent/strategy/test_plugin.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# File: tests/unit_tests/core/agent/strategy/test_plugin.py
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
|
||||
# ============================================================
|
||||
# Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parameter():
|
||||
def _factory(name="param", return_value="initialized"):
|
||||
param = MagicMock()
|
||||
param.name = name
|
||||
param.init_frontend_parameter = MagicMock(return_value=return_value)
|
||||
return param
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_declaration(mock_parameter):
|
||||
param1 = mock_parameter("param1", "init1")
|
||||
param2 = mock_parameter("param2", "init2")
|
||||
|
||||
identity = MagicMock()
|
||||
identity.provider = "provider_x"
|
||||
identity.name = "strategy_x"
|
||||
|
||||
declaration = MagicMock()
|
||||
declaration.parameters = [param1, param2]
|
||||
declaration.identity = identity
|
||||
|
||||
return declaration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(mock_declaration):
|
||||
return PluginAgentStrategy(
|
||||
tenant_id="tenant_123",
|
||||
declaration=mock_declaration,
|
||||
meta_version="v1",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Initialization Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestPluginAgentStrategyInitialization:
|
||||
def test_init_sets_attributes(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version="meta_v",
|
||||
)
|
||||
|
||||
assert strategy.tenant_id == "tenant_test"
|
||||
assert strategy.declaration == mock_declaration
|
||||
assert strategy.meta_version == "meta_v"
|
||||
|
||||
def test_init_meta_version_none(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version=None,
|
||||
)
|
||||
|
||||
assert strategy.meta_version is None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetParameters:
|
||||
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
|
||||
result = strategy.get_parameters()
|
||||
assert result == mock_declaration.parameters
|
||||
|
||||
|
||||
# ============================================================
|
||||
# initialize_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInitializeParameters:
|
||||
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
|
||||
params = {"param1": "value1"}
|
||||
|
||||
result = strategy.initialize_parameters(params.copy())
|
||||
|
||||
assert result["param1"] == "init1"
|
||||
assert result["param2"] == "init2"
|
||||
|
||||
mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1")
|
||||
mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_params",
|
||||
[
|
||||
{},
|
||||
{"param1": None},
|
||||
{"param1": ""},
|
||||
{"param1": 0},
|
||||
{"param1": []},
|
||||
{"param1": {}, "param2": "value"},
|
||||
],
|
||||
)
|
||||
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
|
||||
result = strategy.initialize_parameters(input_params.copy())
|
||||
|
||||
for param in strategy.declaration.parameters:
|
||||
assert param.name in result
|
||||
|
||||
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.initialize_parameters(None)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# _invoke Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInvoke:
|
||||
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mock_convert = mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={"converted": True},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={"param1": "value"},
|
||||
user_id="user_1",
|
||||
conversation_id="conv_1",
|
||||
app_id="app_1",
|
||||
message_id="msg_1",
|
||||
credentials=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == ["msg1", "msg2"]
|
||||
mock_convert.assert_called_once()
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
call_kwargs = mock_manager.invoke.call_args.kwargs
|
||||
assert call_kwargs["tenant_id"] == "tenant_123"
|
||||
assert call_kwargs["user_id"] == "user_1"
|
||||
assert call_kwargs["agent_provider"] == "provider_x"
|
||||
assert call_kwargs["agent_strategy"] == "strategy_x"
|
||||
assert call_kwargs["agent_params"] == {"converted": True}
|
||||
assert call_kwargs["conversation_id"] == "conv_1"
|
||||
assert call_kwargs["app_id"] == "app_1"
|
||||
assert call_kwargs["message_id"] == "msg_1"
|
||||
assert call_kwargs["context"] is not None
|
||||
|
||||
def test_invoke_with_credentials(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
# Patch PluginInvokeContext to bypass pydantic validation
|
||||
mock_context = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginInvokeContext",
|
||||
return_value=mock_context,
|
||||
)
|
||||
|
||||
credentials = MagicMock()
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
(None, None, None),
|
||||
("conv", None, None),
|
||||
(None, "app", None),
|
||||
(None, None, "msg"),
|
||||
],
|
||||
)
|
||||
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
side_effect=ValueError("conversion failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
|
||||
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
802
api/tests/unit_tests/core/agent/test_base_agent_runner.py
Normal file
@@ -0,0 +1,802 @@
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.agent.base_agent_runner as module
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
|
||||
# ==========================================================
|
||||
# Fixtures
|
||||
# ==========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(mocker):
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, mock_db_session):
|
||||
r = BaseAgentRunner.__new__(BaseAgentRunner)
|
||||
r.tenant_id = "tenant"
|
||||
r.user_id = "user"
|
||||
r.agent_thought_count = 0
|
||||
r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1")
|
||||
r.app_config = mocker.MagicMock()
|
||||
r.app_config.app_id = "app1"
|
||||
r.app_config.agent = None
|
||||
r.dataset_tools = []
|
||||
r.application_generate_entity = mocker.MagicMock(invoke_from="test")
|
||||
r._current_thoughts = []
|
||||
return r
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _repack_app_generate_entity
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestRepack:
|
||||
def test_sets_empty_if_none(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = None
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == ""
|
||||
|
||||
def test_keeps_existing(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = "abc"
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == "abc"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# update_prompt_message_tool
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestUpdatePromptTool:
|
||||
def build_param(self, mocker, **kwargs):
|
||||
p = mocker.MagicMock()
|
||||
p.form = kwargs.get("form")
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
p.type = mock_type
|
||||
|
||||
p.name = kwargs.get("name", "p1")
|
||||
p.llm_description = "desc"
|
||||
p.input_schema = kwargs.get("input_schema")
|
||||
p.options = kwargs.get("options")
|
||||
p.required = kwargs.get("required", False)
|
||||
return p
|
||||
|
||||
def test_skip_non_llm(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form="NOT_LLM")
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_enum_and_required(self, runner, mocker):
|
||||
option = mocker.MagicMock(value="opt1")
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
options=[option],
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool = mocker.MagicMock()
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert "p1" in result.parameters["required"]
|
||||
|
||||
def test_skip_file_type_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
|
||||
param.type = module.ToolParameter.ToolParameterType.FILE
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_duplicate_required_not_duplicated(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["required"].count("p1") == 1
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# create_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestCreateAgentThought:
|
||||
def test_with_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=10)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"])
|
||||
assert result == "10"
|
||||
assert runner.agent_thought_count == 1
|
||||
|
||||
def test_without_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=11)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", [])
|
||||
assert result == "11"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# save_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestSaveAgentThought:
|
||||
def setup_agent(self, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;tool2"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
return agent
|
||||
|
||||
def test_not_found(self, runner, mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError):
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
def test_full_update(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mock_label = mocker.MagicMock()
|
||||
mock_label.to_dict.return_value = {"en_US": "label"}
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label)
|
||||
|
||||
usage = mocker.MagicMock(
|
||||
prompt_tokens=1,
|
||||
prompt_price_unit=Decimal("0.1"),
|
||||
prompt_unit_price=Decimal("0.1"),
|
||||
completion_tokens=2,
|
||||
completion_price_unit=Decimal("0.2"),
|
||||
completion_unit_price=Decimal("0.2"),
|
||||
total_tokens=3,
|
||||
total_price=Decimal("0.3"),
|
||||
)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1;tool2",
|
||||
{"a": 1},
|
||||
"thought",
|
||||
{"b": 2},
|
||||
{"meta": 1},
|
||||
"answer",
|
||||
["f1"],
|
||||
usage,
|
||||
)
|
||||
|
||||
assert agent.answer == "answer"
|
||||
assert agent.tokens == 3
|
||||
assert "tool1" in json.loads(agent.tool_labels_str)
|
||||
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
agent.tool = "unknown_tool"
|
||||
mock_db_session.scalar.return_value = agent
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "unknown_tool" in labels
|
||||
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
bad_obj = MagicMock()
|
||||
bad_obj.__str__.return_value = "bad"
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
bad_obj,
|
||||
None,
|
||||
bad_obj,
|
||||
bad_obj,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
{"a": 1},
|
||||
None,
|
||||
{"b": 2},
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_user_prompt
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeUserPrompt:
|
||||
def test_no_files(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[])
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_history
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeHistory:
|
||||
def test_empty(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
|
||||
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input="invalid",
|
||||
observation="invalid",
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=";", thought="thinking")
|
||||
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=json.dumps({"tool1": {"x": 1}}),
|
||||
observation=json.dumps({"tool1": "obs"}),
|
||||
thought="thinking",
|
||||
)
|
||||
|
||||
msg = mocker.MagicMock(
|
||||
id="m100",
|
||||
agent_thoughts=[thought],
|
||||
answer=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _convert_tool_to_prompt_message_tool (new coverage)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestConvertToolToPromptMessageTool:
|
||||
def test_basic_conversion(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
runtime_param = mocker.MagicMock()
|
||||
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param.name = "param1"
|
||||
runtime_param.llm_description = "desc"
|
||||
runtime_param.required = True
|
||||
runtime_param.input_schema = None
|
||||
runtime_param.options = None
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
runtime_param.type = mock_type
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert entity == tool_entity
|
||||
|
||||
def test_full_conversion_multiple_params(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
# LLM param with input_schema override
|
||||
param1 = mocker.MagicMock()
|
||||
param1.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param1.name = "p1"
|
||||
param1.llm_description = "desc"
|
||||
param1.required = True
|
||||
param1.input_schema = {"type": "integer"}
|
||||
param1.options = None
|
||||
param1.type = mocker.MagicMock()
|
||||
|
||||
# SYSTEM_FILES param should be skipped
|
||||
param2 = mocker.MagicMock()
|
||||
param2.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param2.name = "file_param"
|
||||
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert entity == tool_entity
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _init_prompt_tools additional branches
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestInitPromptToolsExtended:
|
||||
def test_agent_tool_branch(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="agent_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert "agent_tool" in tools
|
||||
|
||||
def test_exception_in_conversion(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="bad_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert tools == {}
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestAdditionalCoverage:
|
||||
def test_update_prompt_with_input_schema(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "p1"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.options = None
|
||||
param.input_schema = {"type": "number"}
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
param.type = mock_type
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"]["p1"]["type"] == "number"
|
||||
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {"tool1": {"en_US": "existing"}}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert labels["tool1"]["en_US"] == "existing"
|
||||
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
|
||||
assert agent.tool_meta_str == "meta_string"
|
||||
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
assert prompt is not None
|
||||
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"])
|
||||
mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock())
|
||||
|
||||
mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw))
|
||||
mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result is not None
|
||||
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=None, thought="thinking")
|
||||
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1;tool2",
|
||||
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
|
||||
observation=json.dumps({"tool1": "o1", "tool2": "o2"}),
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
# ================= Additional Surgical Coverage =================
|
||||
|
||||
def test_convert_tool_select_enum_branch(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = True
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert prompt_tool is not None
|
||||
|
||||
|
||||
class TestConvertDatasetRetrieverTool:
|
||||
def test_required_param_added(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
|
||||
assert prompt is not None
|
||||
|
||||
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"])
|
||||
|
||||
llm = mocker.MagicMock()
|
||||
llm.get_model_schema.return_value = mocker.MagicMock(
|
||||
features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION]
|
||||
)
|
||||
model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
|
||||
|
||||
app_config = mocker.MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.agent = None
|
||||
app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"})
|
||||
app_config.additional_features = mocker.MagicMock(show_retrieve_source=True)
|
||||
|
||||
app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
|
||||
message = mocker.MagicMock(id="msg1", conversation_id="conv1")
|
||||
|
||||
runner = BaseAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=app_generate,
|
||||
conversation=mocker.MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=mocker.MagicMock(),
|
||||
config=mocker.MagicMock(),
|
||||
queue_manager=mocker.MagicMock(),
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
assert runner.stream_tool_call is True
|
||||
assert runner.files == ["file1"]
|
||||
assert runner.dataset_tools == ["ds_tool"]
|
||||
assert runner.agent_thought_count == 2
|
||||
|
||||
|
||||
class TestBaseAgentRunnerCoverage:
|
||||
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = "NOT_LLM"
|
||||
param.type = mocker.MagicMock()
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert prompt_tool.parameters["properties"] == {}
|
||||
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
|
||||
dataset_tool = mocker.MagicMock()
|
||||
dataset_tool.entity.identity.name = "ds"
|
||||
runner.dataset_tools = [dataset_tool]
|
||||
|
||||
mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock())
|
||||
|
||||
tools, prompt_tools = runner._init_prompt_tools()
|
||||
|
||||
assert tools["ds"] == dataset_tool
|
||||
assert len(prompt_tools) == 1
|
||||
|
||||
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
|
||||
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
tool_input = {"a": 1}
|
||||
observation = {"b": 2}
|
||||
tool_meta = {"c": 3}
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def dumps_side_effect(value, *args, **kwargs):
|
||||
if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False:
|
||||
raise TypeError("fail")
|
||||
return real_dumps(value, *args, **kwargs)
|
||||
|
||||
mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1",
|
||||
tool_input,
|
||||
None,
|
||||
observation,
|
||||
tool_meta,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
assert isinstance(agent.tool_meta_str, str)
|
||||
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;;"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "" not in labels
|
||||
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
|
||||
system_message = module.SystemPromptMessage(content="sys")
|
||||
|
||||
result = runner.organize_agent_history([system_message])
|
||||
|
||||
assert system_message in result
|
||||
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=None,
|
||||
observation=None,
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"organize_agent_user_prompt",
|
||||
return_value=module.UserPromptMessage(content="user"),
|
||||
)
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
|
||||
assert any(isinstance(item, module.ToolPromptMessage) for item in result)
|
||||
551
api/tests/unit_tests/core/agent/test_cot_agent_runner.py
Normal file
551
api/tests/unit_tests/core/agent/test_cot_agent_runner.py
Normal file
@@ -0,0 +1,551 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
|
||||
|
||||
class DummyRunner(CotAgentRunner):
|
||||
"""Concrete implementation for testing abstract methods."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
# Minimal required defaults
|
||||
self.history_prompt_messages = []
|
||||
self.memory = None
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Prevent BaseAgentRunner __init__ from hitting database
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
|
||||
return_value=[],
|
||||
)
|
||||
# Prepare required constructor dependencies for BaseAgentRunner
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock()
|
||||
application_generate_entity.model_conf.stop = []
|
||||
application_generate_entity.model_conf.provider = "openai"
|
||||
application_generate_entity.model_conf.parameters = {}
|
||||
application_generate_entity.trace_manager = None
|
||||
application_generate_entity.invoke_from = "test"
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock()
|
||||
app_config.agent.max_iteration = 1
|
||||
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.invoke_llm.return_value = []
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.model = "test-model"
|
||||
|
||||
queue_manager = MagicMock()
|
||||
message = MagicMock()
|
||||
|
||||
runner = DummyRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=model_config,
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Patch internal methods to isolate behavior
|
||||
runner._repack_app_generate_entity = MagicMock()
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.create_agent_thought = MagicMock(return_value="thought-id")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
class TestFillInputs:
|
||||
@pytest.mark.parametrize(
|
||||
("instruction", "inputs", "expected"),
|
||||
[
|
||||
("Hello {{name}}", {"name": "John"}, "Hello John"),
|
||||
("No placeholders", {"name": "John"}, "No placeholders"),
|
||||
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
|
||||
("{{x}}", {"x": None}, "None"),
|
||||
("", {"x": "y"}, ""),
|
||||
],
|
||||
)
|
||||
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
|
||||
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestConvertDictToAction:
|
||||
def test_convert_valid_dict(self, runner):
|
||||
action_dict = {"action": "test", "action_input": {"a": 1}}
|
||||
action = runner._convert_dict_to_action(action_dict)
|
||||
assert action.action_name == "test"
|
||||
assert action.action_input == {"a": 1}
|
||||
|
||||
def test_convert_missing_keys(self, runner):
|
||||
with pytest.raises(KeyError):
|
||||
runner._convert_dict_to_action({"invalid": 1})
|
||||
|
||||
|
||||
class TestFormatAssistantMessage:
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner):
|
||||
sp1 = AgentScratchpadUnit(
|
||||
agent_response="resp1",
|
||||
thought="thought1",
|
||||
action_str="action1",
|
||||
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
|
||||
observation="obs1",
|
||||
)
|
||||
sp2 = AgentScratchpadUnit(
|
||||
agent_response="final",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
|
||||
observation=None,
|
||||
)
|
||||
result = runner._format_assistant_message([sp1, sp2])
|
||||
assert "Final Answer:" in result
|
||||
|
||||
def test_format_with_final(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="Done",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
# Simulate final state via action name
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Final Answer" in result
|
||||
|
||||
def test_format_with_action_and_observation(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="resp",
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
action=None,
|
||||
observation="obs",
|
||||
)
|
||||
# Non-final state: provide a non-final action
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Thought:" in result
|
||||
assert "Action:" in result
|
||||
assert "Observation:" in result
|
||||
|
||||
|
||||
class TestHandleInvokeAction:
|
||||
def test_handle_invoke_action_tool_not_present(self, runner):
|
||||
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("result", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, [])
|
||||
assert response == "result"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
)
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
|
||||
usage_1 = LLMUsage.empty_usage()
|
||||
usage_1.prompt_tokens = 1
|
||||
usage_1.completion_tokens = 1
|
||||
usage_1.total_tokens = 2
|
||||
usage_1.prompt_price = 1
|
||||
usage_1.completion_price = 1
|
||||
usage_1.total_price = 2
|
||||
|
||||
usage_2 = LLMUsage.empty_usage()
|
||||
usage_2.prompt_tokens = 1
|
||||
usage_2.completion_tokens = 1
|
||||
usage_2.total_tokens = 2
|
||||
usage_2.prompt_price = 1
|
||||
usage_2.completion_price = 1
|
||||
usage_2.total_price = 2
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
handle_output = mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[
|
||||
[action],
|
||||
[],
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_side_effect(chunks, usage_dict):
|
||||
call_index = handle_output.call_count
|
||||
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
|
||||
return [action] if call_index == 1 else []
|
||||
|
||||
handle_output.side_effect = _handle_side_effect
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
final_usage = results[-1].delta.usage
|
||||
assert final_usage is not None
|
||||
assert final_usage.prompt_tokens == 2
|
||||
assert final_usage.completion_tokens == 2
|
||||
assert final_usage.total_tokens == 4
|
||||
assert final_usage.prompt_price == 2
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
# First iteration → action
|
||||
# Second iteration → no action (empty list)
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[[action], []],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.app_config.agent.max_iteration = 5
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
runner.update_prompt_message_tool.assert_called_once()
|
||||
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="thinking")
|
||||
assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]
|
||||
|
||||
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
|
||||
|
||||
runner.history_prompt_messages = [assistant, tool_msg]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_historic_final_flush_branch(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="final")
|
||||
runner.history_prompt_messages = [assistant]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
|
||||
runner._init_react_state("new-query")
|
||||
|
||||
assert runner._query == "new-query"
|
||||
assert runner._agent_scratchpad == []
|
||||
assert runner._historic_prompt_messages == ["historic"]
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
|
||||
)
|
||||
|
||||
message_file_ids = []
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
|
||||
|
||||
assert response == "ok"
|
||||
assert message_file_ids == ["file1"]
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
|
||||
class TestFillInputsEdgeCases:
|
||||
def test_fill_inputs_with_empty_inputs(self, runner):
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner):
|
||||
class BadValue:
|
||||
def __str__(self):
|
||||
raise Exception("fail")
|
||||
|
||||
# Should silently continue on exception
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker):
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
|
||||
runner.history_prompt_messages = [user_message]
|
||||
|
||||
mock_transform = mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
)
|
||||
mock_transform.return_value.get_prompt.return_value = ["final"]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == ["final"]
|
||||
|
||||
def test_tool_message_without_scratchpad_raises(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
|
||||
|
||||
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
return_value=mock_transform,
|
||||
)
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=["thinking"],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
# Remove invalid branch: Pydantic enforces str|dict for action_input
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "12345"
|
||||
215
api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
Normal file
215
api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyAgentConfig,
|
||||
DummyAppConfig,
|
||||
DummyTool,
|
||||
)
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyPromptEntity as DummyPrompt,
|
||||
)
|
||||
|
||||
|
||||
class DummyFileUploadConfig:
|
||||
def __init__(self, image_config=None):
|
||||
self.image_config = image_config
|
||||
|
||||
|
||||
class DummyImageConfig:
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, file_upload_config=None):
|
||||
self.file_upload_config = file_upload_config
|
||||
|
||||
|
||||
class DummyUnit:
|
||||
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
|
||||
runner._instruction = "test_instruction"
|
||||
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
|
||||
runner._query = "user query"
|
||||
runner._agent_scratchpad = []
|
||||
runner.files = []
|
||||
runner.application_generate_entity = DummyGenerateEntity()
|
||||
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
|
||||
return runner
|
||||
|
||||
|
||||
class TestOrganizeSystemPrompt:
|
||||
def test_organize_system_prompt_success(self, runner, mocker):
|
||||
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_chat_agent_runner.jsonable_encoder",
|
||||
return_value=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
|
||||
result = runner._organize_system_prompt()
|
||||
|
||||
assert "test_instruction" in result.content
|
||||
assert "tool1" in result.content
|
||||
assert "tool2" in result.content
|
||||
assert "tool1, tool2" in result.content
|
||||
|
||||
def test_organize_system_prompt_missing_agent(self, runner):
|
||||
runner.app_config = DummyAppConfig(agent=None)
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
def test_organize_system_prompt_missing_prompt(self, runner):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
|
||||
def test_organize_user_query_no_files(self, runner, files):
|
||||
runner.files = files
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "query"
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
|
||||
image_config = DummyImageConfig(detail="high")
|
||||
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
|
||||
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_no_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assert "system" in result
|
||||
assert "query" in result
|
||||
runner._organize_historic_prompt_messages.assert_called_once()
|
||||
|
||||
def test_with_final_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(final=True, agent_response="done")
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Final Answer: done" in combined
|
||||
|
||||
def test_with_thought_action_observation(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(
|
||||
final=False,
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
observation="observe",
|
||||
)
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: thinking" in combined
|
||||
assert "Action: action" in combined
|
||||
assert "Observation: observe" in combined
|
||||
|
||||
def test_multiple_units_mixed(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
units = [
|
||||
DummyUnit(final=False, thought="t1"),
|
||||
DummyUnit(final=True, agent_response="done"),
|
||||
]
|
||||
runner._agent_scratchpad = units
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: t1" in combined
|
||||
assert "Final Answer: done" in combined
|
||||
@@ -0,0 +1,234 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, dummy_tool_factory):
|
||||
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
|
||||
|
||||
runner._instruction = "Test instruction"
|
||||
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
|
||||
runner._query = "What is Python?"
|
||||
runner._agent_scratchpad = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_completion_agent_runner.jsonable_encoder",
|
||||
side_effect=lambda tools: [{"name": t.name} for t in tools],
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_instruction_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeInstructionPrompt:
|
||||
def test_success_all_placeholders(
|
||||
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = (
|
||||
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
|
||||
)
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
result = runner._organize_instruction_prompt()
|
||||
|
||||
assert "Test instruction" in result
|
||||
assert "toolA" in result
|
||||
assert "toolB" in result
|
||||
tools_payload = json.loads(result.split(" | ")[1])
|
||||
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
|
||||
|
||||
def test_agent_none_raises(self, runner, dummy_app_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=None)
|
||||
with pytest.raises(ValueError, match="Agent configuration is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
|
||||
with pytest.raises(ValueError, match="prompt entity is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_historic_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[user_msg, assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_prompt_messages Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_full_flow_with_scratchpad(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
|
||||
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
|
||||
content = result[0].content
|
||||
|
||||
assert "History" in content
|
||||
assert "Thought: Thinking" in content
|
||||
assert "Action: Act" in content
|
||||
assert "Observation: Obs" in content
|
||||
assert "Final Answer: Done" in content
|
||||
assert "Question: What is Python?" in content
|
||||
|
||||
def test_no_scratchpad(
|
||||
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = None
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert "Question: What is Python?" in result[0].content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("thought", "action", "observation"),
|
||||
[
|
||||
("T", None, None),
|
||||
("T", "A", None),
|
||||
("T", None, "O"),
|
||||
],
|
||||
)
|
||||
def test_partial_scratchpad_units(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
thought,
|
||||
action,
|
||||
observation,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(
|
||||
final=False,
|
||||
thought=thought,
|
||||
action_str=action,
|
||||
observation=observation,
|
||||
)
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
content = result[0].content
|
||||
|
||||
assert "Thought:" in content
|
||||
if action:
|
||||
assert "Action:" in content
|
||||
if observation:
|
||||
assert "Observation:" in content
|
||||
452
api/tests/unit_tests/core/agent/test_fc_agent_runner.py
Normal file
452
api/tests/unit_tests/core/agent/test_fc_agent_runner.py
Normal file
@@ -0,0 +1,452 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
|
||||
# ==============================
|
||||
# Dummy Helper Classes
|
||||
# ==============================
|
||||
|
||||
|
||||
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = pt
|
||||
usage.completion_tokens = ct
|
||||
usage.total_tokens = tt
|
||||
usage.prompt_price = 0
|
||||
usage.completion_price = 0
|
||||
usage.total_price = 0
|
||||
return usage
|
||||
|
||||
|
||||
class DummyMessage:
|
||||
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
|
||||
self.content: str | None = content
|
||||
self.tool_calls: list[Any] = tool_calls or []
|
||||
|
||||
|
||||
class DummyDelta:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
|
||||
|
||||
|
||||
class DummyResult:
|
||||
def __init__(
|
||||
self,
|
||||
message: DummyMessage | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
prompt_messages: list[DummyMessage] | None = None,
|
||||
):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
self.prompt_messages: list[DummyMessage] = prompt_messages or []
|
||||
self.system_fingerprint: str = ""
|
||||
|
||||
|
||||
# ==============================
|
||||
# Fixtures
|
||||
# ==============================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
# Patch streaming chunk models to avoid validation on dummy message objects
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock(max_iteration=2)
|
||||
app_config.prompt_template = MagicMock(simple_prompt_template="system")
|
||||
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
|
||||
application_generate_entity.trace_manager = MagicMock()
|
||||
application_generate_entity.invoke_from = "test"
|
||||
application_generate_entity.app_config = MagicMock(app_id="app")
|
||||
application_generate_entity.file_upload_config = None
|
||||
|
||||
queue_manager = MagicMock()
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
|
||||
message = MagicMock(id="msg1")
|
||||
conversation = MagicMock(id="conv1")
|
||||
|
||||
runner = FunctionCallAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=MagicMock(),
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Manually inject required attributes normally set by BaseAgentRunner
|
||||
runner.tenant_id = "tenant"
|
||||
runner.application_generate_entity = application_generate_entity
|
||||
runner.conversation = conversation
|
||||
runner.app_config = app_config
|
||||
runner.model_config = MagicMock()
|
||||
runner.config = MagicMock()
|
||||
runner.queue_manager = queue_manager
|
||||
runner.message = message
|
||||
runner.user_id = "user"
|
||||
runner.model_instance = model_instance
|
||||
|
||||
runner.stream_tool_call = False
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.agent_callback = MagicMock()
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.create_agent_thought = MagicMock(return_value="thought1")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ==============================
|
||||
# Tool Call Checks
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestToolCallChecks:
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_tool_calls(self, runner, tool_calls, expected):
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_tool_calls(chunk) is expected
|
||||
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
|
||||
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_blocking_tool_calls(result) is expected
|
||||
|
||||
|
||||
# ==============================
|
||||
# Extract Tool Calls
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_extract_tool_calls_with_valid_json(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {"a": 1})]
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = ""
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {})]
|
||||
|
||||
def test_extract_blocking_tool_calls(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "2"
|
||||
tool_call.function.name = "block"
|
||||
tool_call.function.arguments = json.dumps({"x": 2})
|
||||
|
||||
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_blocking_tool_calls(result)
|
||||
|
||||
assert calls == [("2", "block", {"x": 2})]
|
||||
|
||||
|
||||
# ==============================
|
||||
# System Message Initialization
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
def test_init_system_message_empty_prompt_messages(self, runner):
|
||||
result = runner._init_system_message("system", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_init_system_message_insert_at_start(self, runner):
|
||||
msgs = [MagicMock()]
|
||||
result = runner._init_system_message("system", msgs)
|
||||
assert result[0].content == "system"
|
||||
|
||||
def test_init_system_message_no_template(self, runner):
|
||||
result = runner._init_system_message("", [])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ==============================
|
||||
# Organize User Query
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
def test_without_files(self, runner):
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_none_query(self, runner):
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
return_value=file_content,
|
||||
)
|
||||
|
||||
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
|
||||
runner.files = ["file1"]
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Clear User Prompt Images
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
def test_clear_text_and_image_content(self, runner):
|
||||
text = MagicMock()
|
||||
text.type = "text"
|
||||
text.data = "hello"
|
||||
|
||||
image = MagicMock()
|
||||
image.type = "image"
|
||||
image.data = "img"
|
||||
|
||||
user_msg = MagicMock()
|
||||
user_msg.__class__.__name__ = "UserPromptMessage"
|
||||
user_msg.content = [text, image]
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_clear_includes_file_placeholder(self, runner):
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
image = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
|
||||
|
||||
user_msg = UserPromptMessage(content=[text, image, document])
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
|
||||
assert result[0].content == "hello\n[image]\n[file]"
|
||||
|
||||
|
||||
# ==============================
|
||||
# Run Method Tests
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestRunMethod:
|
||||
def test_run_non_streaming_no_tool_calls(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
dummy_message = DummyMessage(content="hello")
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
queue_calls = runner.queue_manager.publish.call_args_list
|
||||
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
|
||||
|
||||
def test_run_streaming_branch(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_streaming_tool_calls_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_non_streaming_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
dummy_message = DummyMessage(content=content)
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def flaky_dumps(obj, *args, **kwargs):
|
||||
if kwargs.get("ensure_ascii") is False:
|
||||
return real_dumps(obj, *args, **kwargs)
|
||||
raise TypeError("boom")
|
||||
|
||||
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_with_missing_tool_instance(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "missing"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
tool_instance = MagicMock()
|
||||
prompt_tool = MagicMock()
|
||||
prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
|
||||
|
||||
tool_invoke_meta = MagicMock()
|
||||
tool_invoke_meta.to_dict.return_value = {"ok": True}
|
||||
mocker.patch(
|
||||
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], tool_invoke_meta),
|
||||
)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
assert any(
|
||||
isinstance(call.args[0], QueueMessageFileEvent)
|
||||
and call.args[0].message_file_id == "file1"
|
||||
and call.args[1] == PublishFrom.APPLICATION_MANAGER
|
||||
for call in runner.queue_manager.publish.call_args_list
|
||||
)
|
||||
|
||||
def test_run_max_iteration_error(self, runner):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = "{}"
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query"))
|
||||
324
api/tests/unit_tests/core/agent/test_plugin_entities.py
Normal file
@@ -0,0 +1,324 @@
"""Unit tests for core.agent.plugin_entities.

Covers entities such as AgentFeature, AgentProviderEntityWithPlugin,
AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter,
AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on
Pydantic ValidationError behavior and pytest fixtures for validation and
mocking; ensure entity invariants and validation rules remain stable.
"""

import pytest
from pydantic import ValidationError

from core.agent.plugin_entities import (
    AgentFeature,
    AgentProviderEntityWithPlugin,
    AgentStrategyEntity,
    AgentStrategyIdentity,
    AgentStrategyParameter,
    AgentStrategyProviderEntity,
    AgentStrategyProviderIdentity,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity

# =========================================================
# Fixtures
# =========================================================


@pytest.fixture
def mock_identity(mocker):
    return mocker.MagicMock(spec=AgentStrategyIdentity)


@pytest.fixture
def mock_provider_identity(mocker):
    return mocker.MagicMock(spec=AgentStrategyProviderIdentity)

# =========================================================
|
||||
# AgentStrategyParameterType Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameterType:
|
||||
@pytest.mark.parametrize(
|
||||
"enum_member",
|
||||
list(AgentStrategyParameter.AgentStrategyParameterType),
|
||||
)
|
||||
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
return_value="normalized",
|
||||
)
|
||||
|
||||
result = enum_member.as_normal_type()
|
||||
|
||||
mock_func.assert_called_once_with(enum_member)
|
||||
assert result == "normalized"
|
||||
|
||||
def test_as_normal_type_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
enum_member.as_normal_type()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("enum_member", "value"),
|
||||
[
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, None),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
|
||||
],
|
||||
)
|
||||
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
return_value="casted",
|
||||
)
|
||||
|
||||
result = enum_member.cast_value(value)
|
||||
|
||||
mock_func.assert_called_once_with(enum_member, value)
|
||||
assert result == "casted"
|
||||
|
||||
def test_cast_value_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
side_effect=ValueError("invalid"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
enum_member.cast_value("bad")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameter Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameter:
|
||||
def test_valid_creation_minimal(self) -> None:
|
||||
# bypass base PluginParameter required fields using model_construct
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=None,
|
||||
)
|
||||
assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
assert param.help is None
|
||||
|
||||
def test_valid_creation_with_help(self) -> None:
|
||||
help_obj = I18nObject(en_US="test")
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=help_obj,
|
||||
)
|
||||
assert param.help == help_obj
|
||||
|
||||
@pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}])
|
||||
def test_invalid_type_raises_validation_error(self, invalid_type) -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y"))
|
||||
|
||||
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
|
||||
|
||||
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
return_value="frontend",
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
result = param.init_frontend_parameter("value")
|
||||
|
||||
mock_func.assert_called_once_with(param, param.type, "value")
|
||||
assert result == "frontend"
|
||||
|
||||
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
side_effect=RuntimeError("error"),
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
param.init_frontend_parameter("value")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyProviderEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyProviderEntity:
|
||||
def test_creation_with_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="plugin-123",
|
||||
)
|
||||
assert entity.plugin_id == "plugin-123"
|
||||
|
||||
def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="",
|
||||
)
|
||||
assert entity.plugin_id == ""
|
||||
|
||||
def test_creation_without_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(identity=mock_provider_identity)
|
||||
assert entity.plugin_id is None
|
||||
|
||||
def test_invalid_identity_raises(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyProviderEntity(identity="invalid")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyEntity:
|
||||
def test_parameters_default_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_none_converted_to_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=None,
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_preserved(self, mock_identity) -> None:
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[param],
|
||||
)
|
||||
assert entity.parameters == [param]
|
||||
|
||||
def test_invalid_parameters_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters="invalid",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"features",
|
||||
[
|
||||
None,
|
||||
[],
|
||||
[AgentFeature.HISTORY_MESSAGES],
|
||||
],
|
||||
)
|
||||
def test_features_valid(self, mock_identity, features) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features=features,
|
||||
)
|
||||
assert entity.features == features
|
||||
|
||||
def test_invalid_features_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features="invalid",
|
||||
)
|
||||
|
||||
def test_output_schema_and_meta_version(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
output_schema={"type": "object"},
|
||||
meta_version="v1",
|
||||
)
|
||||
assert entity.output_schema == {"type": "object"}
|
||||
assert entity.meta_version == "v1"
|
||||
|
||||
def test_missing_required_fields_raise(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(identity=mock_identity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentProviderEntityWithPlugin Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentProviderEntityWithPlugin:
|
||||
def test_default_strategies_empty(self, mock_provider_identity) -> None:
|
||||
entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity)
|
||||
assert entity.strategies == []
|
||||
|
||||
def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None:
|
||||
strategy = AgentStrategyEntity.model_construct(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
entity = AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies=[strategy],
|
||||
)
|
||||
assert entity.strategies == [strategy]
|
||||
|
||||
def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies="invalid",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# Inheritance Smoke Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestInheritanceBehavior:
|
||||
def test_agent_strategy_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyIdentity, ToolIdentity)
|
||||
|
||||
def test_agent_strategy_provider_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)
|
||||
0
api/tests/unit_tests/core/app/apps/__init__.py
Normal file
@@ -0,0 +1,75 @@
from types import SimpleNamespace
from unittest.mock import patch

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from models.model import AppMode


class TestAdvancedChatAppConfigManager:
    def test_get_app_config(self):
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value)
        workflow = SimpleNamespace(id="wf-1", features_dict={})

        with (
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
                return_value=None,
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert",
                return_value=[],
            ),
        ):
            app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow)

        assert app_config.workflow_id == "wf-1"
        assert app_config.app_mode == AppMode.ADVANCED_CHAT

    def test_config_validate_filters_keys(self):
        def _add_key(key, value):
            def _inner(*args, **kwargs):
                config = kwargs.get("config") if kwargs else args[-1]
                config = {**config, key: value}
                return config, [key]

            return _inner

        with (
            patch(
                "core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
                side_effect=_add_key("file_upload", 1),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
                side_effect=_add_key("opening_statement", 2),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
                side_effect=_add_key("suggested_questions_after_answer", 3),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
                side_effect=_add_key("speech_to_text", 4),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
                side_effect=_add_key("text_to_speech", 5),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("retriever_resource", 6),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("sensitive_word_avoidance", 7),
            ),
        ):
            filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={})

        assert filtered["file_upload"] == 1
        assert filtered["opening_statement"] == 2
        assert filtered["suggested_questions_after_answer"] == 3
        assert filtered["speech_to_text"] == 4
        assert filtered["text_to_speech"] == 5
        assert filtered["retriever_resource"] == 6
        assert filtered["sensitive_word_avoidance"] == 7
File diff suppressed because it is too large
@@ -0,0 +1,96 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateResponseConverter:
|
||||
def test_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_stream_simple_response_includes_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
elapsed_time=0.1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_start,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_finish,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert converted[0] == "ping"
|
||||
assert converted[1]["event"] == "node_started"
|
||||
assert converted[2]["event"] == "node_finished"
|
||||
assert converted[3]["event"] == "error"
|
||||
@@ -0,0 +1,600 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.enums import MessageStatus
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _make_pipeline():
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=datetime.utcnow(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
|
||||
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=False,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_ensure_workflow_initialized_raises(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
with pytest.raises(ValueError, match="workflow run not initialized"):
|
||||
pipeline._ensure_workflow_initialized()
|
||||
|
||||
def test_to_blocking_response_returns_message_end(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.answer = "done"
|
||||
|
||||
def _gen():
|
||||
yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"})
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.answer == "done"
|
||||
assert response.data.metadata == {"k": "v"}
|
||||
|
||||
def test_handle_text_chunk_event_updates_state(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager = SimpleNamespace(
|
||||
message_to_stream_response=lambda **kwargs: MessageEndStreamResponse(
|
||||
task_id="task", id="message-id", metadata={}
|
||||
)
|
||||
)
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=None)
|
||||
|
||||
responses = list(pipeline._handle_text_chunk_event(event))
|
||||
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert responses
|
||||
|
||||
def test_listen_audio_msg_returns_audio_stream(self):
|
||||
pipeline = _make_pipeline()
|
||||
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
|
||||
|
||||
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
|
||||
|
||||
assert isinstance(response, MessageAudioStreamResponse)
|
||||
|
||||
def test_handle_ping_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
|
||||
|
||||
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
|
||||
|
||||
assert isinstance(responses[0], PingStreamResponse)
|
||||
|
||||
def test_handle_error_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
pipeline._database_session = _fake_session
|
||||
|
||||
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
|
||||
|
||||
assert isinstance(responses[0], ValueError)
|
||||
|
||||
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
|
||||
|
||||
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
|
||||
|
||||
assert pipeline._workflow_run_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
def test_message_end_to_stream_response_strips_annotation_reply(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id="ann",
|
||||
account=AnnotationReplyAccount(id="acc", name="acc"),
|
||||
)
|
||||
|
||||
response = pipeline._message_end_to_stream_response()
|
||||
|
||||
assert "annotation_reply" not in response.metadata
|
||||
|
||||
def test_handle_output_moderation_chunk_publishes_stop(self):
|
||||
pipeline = _make_pipeline()
|
||||
events: list[object] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return True
|
||||
|
||||
def get_final_output(self):
|
||||
return "final"
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
pipeline._base_task_pipeline.queue_manager = SimpleNamespace(
|
||||
publish=lambda event, pub_from: events.append(event)
|
||||
)
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("ignored")
|
||||
|
||||
assert result is True
|
||||
assert pipeline._task_state.answer == "final"
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in events)
|
||||
assert any(isinstance(event, QueueStopEvent) for event in events)
|
||||
|
||||
def test_handle_node_succeeded_event_records_files(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [
|
||||
{"type": "file", "transfer_method": "local"}
|
||||
]
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
event = SimpleNamespace(
|
||||
node_type=NodeType.ANSWER,
|
||||
outputs={"k": "v"},
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
assert responses == ["done"]
|
||||
assert pipeline._recorded_files
|
||||
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
|
||||
"iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
|
||||
"iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
|
||||
|
||||
iter_start = QueueIterationStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_next = QueueIterationNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_done = QueueIterationCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_start = QueueLoopStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_next = QueueLoopNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_done = QueueLoopCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"]
|
||||
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"]
|
||||
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"]
|
||||
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"]
|
||||
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
|
||||
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
|
||||
|
||||
def test_workflow_finish_handlers(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"]
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: None
|
||||
pipeline._save_message = lambda **kwargs: None
|
||||
pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id")
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace(scalar=lambda *args, **kwargs: None)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={})))
|
||||
assert len(succeeded_responses) == 2
|
||||
assert isinstance(succeeded_responses[0], MessageEndStreamResponse)
|
||||
assert succeeded_responses[1] == "finish"
|
||||
|
||||
partial_success_responses = list(
|
||||
pipeline._handle_workflow_partial_success_event(
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
)
|
||||
)
|
||||
assert len(partial_success_responses) == 2
|
||||
assert isinstance(partial_success_responses[0], MessageEndStreamResponse)
|
||||
assert partial_success_responses[1] == "finish"
|
||||
assert (
|
||||
list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0]
|
||||
== "finish"
|
||||
)
|
||||
assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [
|
||||
"pause"
|
||||
]
|
||||
|
||||
def test_node_failure_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
failed_event = QueueNodeFailedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
exc_event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"]
|
||||
assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"]
|
||||
|
||||
def test_handle_text_chunk_event_tracks_streaming_metrics(self):
|
||||
pipeline = _make_pipeline()
|
||||
published: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
published.append(message)
|
||||
|
||||
pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=["a"])
|
||||
queue_message = SimpleNamespace(event=event)
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
|
||||
)
|
||||
|
||||
assert responses == ["chunk"]
|
||||
assert pipeline._task_state.is_streaming_response is True
|
||||
assert pipeline._task_state.first_token_time is not None
|
||||
assert pipeline._task_state.last_token_time is not None
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert published == [queue_message]
|
||||
|
||||
def test_handle_output_moderation_chunk_appends_token(self):
|
||||
pipeline = _make_pipeline()
|
||||
seen: list[str] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return False
|
||||
|
||||
def append_new_token(self, text):
|
||||
seen.append(text)
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("token")
|
||||
|
||||
assert result is False
|
||||
assert seen == ["token"]
|
||||
|
||||
def test_handle_retriever_and_annotation_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"retriever": 0, "annotation": 0}
|
||||
|
||||
def _hit_retriever(event):
|
||||
calls["retriever"] += 1
|
||||
|
||||
def _hit_annotation(event):
|
||||
calls["annotation"] += 1
|
||||
|
||||
pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever
|
||||
pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation
|
||||
|
||||
retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[])
|
||||
annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann")
|
||||
|
||||
assert list(pipeline._handle_retriever_resources_event(retriever_event)) == []
|
||||
assert list(pipeline._handle_annotation_reply_event(annotation_event)) == []
|
||||
assert calls == {"retriever": 1, "annotation": 1}
|
||||
|
||||
def test_handle_message_replace_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
|
||||
event = QueueMessageReplaceEvent(
|
||||
text="new",
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_message_replace_event(event)) == ["replace"]
|
||||
|
||||
def test_handle_human_input_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
persisted: list[str] = []
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved")
|
||||
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
|
||||
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
|
||||
|
||||
filled_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
action_id="action",
|
||||
action_text="action",
|
||||
)
|
||||
timeout_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
expiration_time=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
|
||||
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
|
||||
assert persisted == ["saved"]
|
||||
|
||||
def test_save_message_strips_markdown_and_sets_usage(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._recorded_files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote",
|
||||
"remote_url": "http://example.com/file.png",
|
||||
"related_id": "file-id",
|
||||
}
|
||||
]
|
||||
pipeline._task_state.answer = " hello"
|
||||
pipeline._task_state.is_streaming_response = True
|
||||
pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1
|
||||
pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
status=MessageStatus.PAUSED,
|
||||
answer="",
|
||||
updated_at=None,
|
||||
provider_response_latency=None,
|
||||
message_tokens=None,
|
||||
message_unit_price=None,
|
||||
message_price_unit=None,
|
||||
answer_tokens=None,
|
||||
answer_unit_price=None,
|
||||
answer_price_unit=None,
|
||||
total_price=None,
|
||||
currency=None,
|
||||
message_metadata=None,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_account_id=None,
|
||||
from_end_user_id="end-user",
|
||||
)
|
||||
|
||||
class _Session:
|
||||
def scalar(self, *args, **kwargs):
|
||||
return message
|
||||
|
||||
def add_all(self, items):
|
||||
self.items = items
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state)
|
||||
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
assert message.answer == "hello"
|
||||
assert message.message_metadata
|
||||
|
||||
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)))
|
||||
|
||||
assert responses == ["end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent()))
|
||||
|
||||
assert responses == ["replace", "end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_dispatch_event_handles_node_exception(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
|
||||
pipeline._save_output_for_event = lambda *args, **kwargs: None
|
||||
|
||||
event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._dispatch_event(event)) == ["failed"]
|
||||
@@ -0,0 +1,302 @@
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.agent_chat.app_config_manager import (
|
||||
AgentChatAppConfigManager,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerGetAppConfig:
|
||||
def test_get_app_config_override_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"ignored": True}
|
||||
|
||||
override_config = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.variables == "variables"
|
||||
assert result.external_data_variables == "external"
|
||||
|
||||
def test_get_app_config_conversation_specific(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
conversation = mocker.MagicMock()
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == app_model_config.to_dict.return_value
|
||||
assert result.app_model_config_from.value == "conversation-specific-config"
|
||||
|
||||
def test_get_app_config_latest_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from.value == "app-latest-config"
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerConfigValidate:
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {},
|
||||
"user_input_form": {},
|
||||
"file_upload": {},
|
||||
"prompt_template": {},
|
||||
"agent_mode": {},
|
||||
"opening_statement": {},
|
||||
"suggested_questions_after_answer": {},
|
||||
"speech_to_text": {},
|
||||
"text_to_speech": {},
|
||||
"retriever_resource": {},
|
||||
"dataset": {},
|
||||
"moderation": {},
|
||||
"extra": "value",
|
||||
}
|
||||
|
||||
def return_with_key(key):
|
||||
return config, [key]
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("model"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("file_upload"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda app_mode, cfg: return_with_key("prompt_template"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
AgentChatAppConfigManager,
|
||||
"validate_agent_mode_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("opening_statement"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("speech_to_text"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("text_to_speech"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("retriever_resource"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("moderation"),
|
||||
)
|
||||
|
||||
filtered = AgentChatAppConfigManager.config_validate("tenant", config)
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"prompt_template",
|
||||
"agent_mode",
|
||||
"opening_statement",
|
||||
"suggested_questions_after_answer",
|
||||
"speech_to_text",
|
||||
"text_to_speech",
|
||||
"retriever_resource",
|
||||
"dataset",
|
||||
"moderation",
|
||||
}
|
||||
assert "extra" not in filtered
|
||||
|
||||
|
||||
class TestValidateAgentModeAndSetDefaults:
    def test_defaults_when_missing(self):
        config = {}
        updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
        assert "agent_mode" in updated
        assert updated["agent_mode"]["enabled"] is False
        assert updated["agent_mode"]["tools"] == []
        assert keys == ["agent_mode"]

    @pytest.mark.parametrize(
        "agent_mode",
        ["invalid", 123],
    )
    def test_agent_mode_type_validation(self, agent_mode):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode})

    def test_agent_mode_empty_list_defaults(self):
        config = {"agent_mode": []}
        updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
        assert updated["agent_mode"]["enabled"] is False
        assert updated["agent_mode"]["tools"] == []

    def test_enabled_must_be_bool(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}})

    def test_strategy_must_be_valid(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}}
            )

    def test_tools_must_be_list(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}}
            )

    def test_old_tool_dataset_requires_id(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}}
            )

    def test_old_tool_dataset_id_must_be_uuid(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant",
                {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
            )

    def test_old_tool_dataset_id_not_exists(self, mocker):
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
            return_value=False,
        )
        dataset_id = str(uuid.uuid4())
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant",
                {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}},
            )

    def test_old_tool_enabled_must_be_bool(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant",
                {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}},
            )

    @pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"])
    def test_new_style_tool_requires_fields(self, missing_key):
        tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"}
        tool.pop(missing_key, None)
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
            )

    def test_valid_old_and_new_style_tools(self, mocker):
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
            return_value=True,
        )
        dataset_id = str(uuid.uuid4())
        config = {
            "agent_mode": {
                "enabled": True,
                "strategy": PlanningStrategy.ROUTER.value,
                "tools": [
                    {"dataset": {"id": dataset_id}},
                    {
                        "provider_type": "builtin",
                        "provider_id": "p1",
                        "tool_name": "tool",
                        "tool_parameters": {},
                    },
                ],
            }
        }

        updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
        assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False
        assert updated["agent_mode"]["tools"][1]["enabled"] is False
@@ -0,0 +1,296 @@
import contextlib

import pytest
from pydantic import ValidationError

from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError


class DummyAccount:
    def __init__(self, user_id):
        self.id = user_id
        self.session_id = f"session-{user_id}"


@pytest.fixture
def generator(mocker):
    gen = AgentChatAppGenerator()
    mocker.patch(
        "core.app.apps.agent_chat.app_generator.current_app",
        new=mocker.MagicMock(_get_current_object=mocker.MagicMock()),
    )
    mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx")
    return gen


class TestAgentChatAppGeneratorGenerate:
    def test_generate_rejects_blocking_mode(self, generator, mocker):
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)

    def test_generate_requires_query(self, generator, mocker):
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())

    def test_generate_rejects_non_string_query(self, generator, mocker):
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(
                app_model=app_model,
                user=user,
                args={"query": 123, "inputs": {}},
                invoke_from=mocker.MagicMock(),
            )

    def test_generate_override_requires_debugger(self, generator, mocker):
        app_model = mocker.MagicMock()
        user = DummyAccount("user")

        with pytest.raises(ValueError):
            generator.generate(
                app_model=app_model,
                user=user,
                args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}},
                invoke_from=InvokeFrom.WEB_APP,
            )

    def test_generate_success_with_debugger_override(self, generator, mocker):
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "p"}}

        user = DummyAccount("user")
        invoke_from = InvokeFrom.DEBUGGER

        generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
        generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
        generator._init_generate_records = mocker.MagicMock(
            return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
        )
        generator._handle_response = mocker.MagicMock(return_value="response")

        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate",
            return_value={"validated": True},
        )
        app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[])
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
            return_value=app_config,
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
            return_value=["file-obj"],
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.ConversationService.get_conversation",
            return_value=mocker.MagicMock(id="conv"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.TraceQueueManager",
            return_value=mocker.MagicMock(),
        )

        queue_manager = mocker.MagicMock()
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
            return_value=queue_manager,
        )

        thread_obj = mocker.MagicMock()
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.threading.Thread",
            return_value=thread_obj,
        )

        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
            return_value={"result": "ok"},
        )
        app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from)
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
            return_value=app_entity,
        )

        args = {
            "query": "hello",
            "inputs": {"name": "world"},
            "conversation_id": "conv",
            "model_config": {"model": {"provider": "p"}},
            "files": [{"id": "f1"}],
        }

        result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True)

        assert result == {"result": "ok"}
        thread_obj.start.assert_called_once()

    def test_generate_without_file_config(self, generator, mocker):
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "p"}}

        user = DummyAccount("user")

        generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
        generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
        generator._init_generate_records = mocker.MagicMock(
            return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
        )
        generator._handle_response = mocker.MagicMock(return_value="response")

        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
            return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
            return_value=None,
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
            return_value=["file-obj"],
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.TraceQueueManager",
            return_value=mocker.MagicMock(),
        )

        mocker.patch(
            "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
            return_value=mocker.MagicMock(),
        )

        thread_obj = mocker.MagicMock()
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.threading.Thread",
            return_value=thread_obj,
        )

        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
            return_value={"result": "ok"},
        )
        app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP)
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
            return_value=app_entity,
        )

        args = {"query": "hello", "inputs": {"name": "world"}}

        result = generator.generate(
            app_model=app_model,
            user=user,
            args=args,
            invoke_from=InvokeFrom.WEB_APP,
            streaming=True,
        )

        assert result == {"result": "ok"}


class TestAgentChatAppGeneratorWorker:
    @pytest.fixture(autouse=True)
    def patch_context(self, mocker):
        @contextlib.contextmanager
        def ctx_manager(*args, **kwargs):
            yield

        mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)

    def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
        queue_manager = mocker.MagicMock()
        generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
        generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())

        runner = mocker.MagicMock()
        runner.run.side_effect = GenerateTaskStoppedError()
        mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
        mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")

        generator._generate_worker(
            flask_app=mocker.MagicMock(),
            context=mocker.MagicMock(),
            application_generate_entity=mocker.MagicMock(),
            queue_manager=queue_manager,
            conversation_id="conv",
            message_id="msg",
        )

        queue_manager.publish_error.assert_not_called()

    @pytest.mark.parametrize(
        "error",
        [
            InvokeAuthorizationError("bad"),
            ValidationError.from_exception_data("TestModel", []),
            ValueError("bad"),
            Exception("bad"),
        ],
    )
    def test_generate_worker_publishes_errors(self, generator, mocker, error):
        queue_manager = mocker.MagicMock()
        generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
        generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())

        runner = mocker.MagicMock()
        runner.run.side_effect = error
        mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
        mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")

        generator._generate_worker(
            flask_app=mocker.MagicMock(),
            context=mocker.MagicMock(),
            application_generate_entity=mocker.MagicMock(),
            queue_manager=queue_manager,
            conversation_id="conv",
            message_id="msg",
        )

        assert queue_manager.publish_error.called

    def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
        queue_manager = mocker.MagicMock()
        generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
        generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())

        runner = mocker.MagicMock()
        runner.run.side_effect = ValueError("bad")
        mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
        mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")

        mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True))
        logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger")

        generator._generate_worker(
            flask_app=mocker.MagicMock(),
            context=mocker.MagicMock(),
            application_generate_entity=mocker.MagicMock(),
            queue_manager=queue_manager,
            conversation_id="conv",
            message_id="msg",
        )

        logger.exception.assert_called_once()
@@ -0,0 +1,413 @@
import pytest

from core.agent.entities import AgentEntity
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.moderation.base import ModerationError
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey


@pytest.fixture
def runner():
    return AgentChatAppRunner()


class TestAgentChatAppRunnerRun:
    def test_run_app_not_found(self, runner, mocker):
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
        generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    def test_run_moderation_error_direct_output(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock()
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(),
            conversation_id=None,
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
        mocker.patch.object(runner, "direct_output")

        runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

        runner.direct_output.assert_called_once()

    def test_run_annotation_reply_short_circuits(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock()
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(),
            conversation_id=None,
            user_id="user",
            invoke_from=mocker.MagicMock(),
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        annotation = mocker.MagicMock(id="anno", content="answer")
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation)
        mocker.patch.object(runner, "direct_output")

        queue_manager = mocker.MagicMock()
        runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock())

        queue_manager.publish.assert_called_once()
        runner.direct_output.assert_called_once()

    def test_run_hosting_moderation_short_circuits(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock()
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(),
            conversation_id=None,
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=True)

        runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    def test_run_model_schema_missing(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = None
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    @pytest.mark.parametrize(
        ("mode", "expected_runner"),
        [
            (LLMMode.CHAT, "CotChatAgentRunner"),
            (LLMMode.COMPLETION, "CotCompletionAgentRunner"),
        ],
    )
    def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: mode}

        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)

        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )

        runner_cls = mocker.MagicMock()
        mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)

        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
        runner_instance.run.return_value = []
        mocker.patch.object(runner, "_handle_invoke_result")

        runner.run(generate_entity, mocker.MagicMock(), conversation, message)

        runner_instance.run.assert_called_once()
        runner._handle_invoke_result.assert_called_once()

    def test_run_invalid_llm_mode_raises(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}

        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)

        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), conversation, message)

    def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        model_schema = mocker.MagicMock()
        model_schema.features = [ModelFeature.TOOL_CALL]
        model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}

        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)

        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )

        runner_cls = mocker.MagicMock()
        mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)

        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
        runner_instance.run.return_value = []
        mocker.patch.object(runner, "_handle_invoke_result")

        runner.run(generate_entity, mocker.MagicMock(), conversation, message)

        assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
        runner_instance.run.assert_called_once()

    def test_run_conversation_not_found(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, None],
        )
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))

    def test_run_message_not_found(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, mocker.MagicMock(id="conv"), None],
        )
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))

    def test_run_invalid_agent_strategy_raises(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")

        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )

        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)

        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}

        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)

        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), conversation, message)
@@ -0,0 +1,162 @@
from collections.abc import Generator

from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
    ErrorStreamResponse,
    MessageEndStreamResponse,
    MessageStreamResponse,
    PingStreamResponse,
)


class TestAgentChatAppGenerateResponseConverterBlocking:
    def test_convert_blocking_full_response(self):
        blocking = ChatbotAppBlockingResponse(
            task_id="task",
            data=ChatbotAppBlockingResponse.Data(
                id="id",
                mode="agent-chat",
                conversation_id="conv",
                message_id="msg",
                answer="answer",
                metadata={"a": 1},
                created_at=123,
            ),
        )

        result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)

        assert result["event"] == "message"
        assert result["answer"] == "answer"
        assert result["metadata"] == {"a": 1}

    def test_convert_blocking_simple_response_with_dict_metadata(self):
        blocking = ChatbotAppBlockingResponse(
            task_id="task",
            data=ChatbotAppBlockingResponse.Data(
                id="id",
                mode="agent-chat",
                conversation_id="conv",
                message_id="msg",
                answer="answer",
                metadata={
                    "retriever_resources": [
                        {
                            "segment_id": "s1",
                            "position": 1,
                            "document_name": "doc",
                            "score": 0.9,
                            "content": "content",
                        }
                    ],
                    "annotation_reply": {"id": "a"},
                    "usage": {"prompt_tokens": 1},
                },
                created_at=123,
            ),
        )

        result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)

        assert "annotation_reply" not in result["metadata"]
        assert "usage" not in result["metadata"]

    def test_convert_blocking_simple_response_with_non_dict_metadata(self):
        blocking = ChatbotAppBlockingResponse.model_construct(
            task_id="task",
            data=ChatbotAppBlockingResponse.Data.model_construct(
                id="id",
                mode="agent-chat",
                conversation_id="conv",
                message_id="msg",
                answer="answer",
                metadata="bad",
                created_at=123,
            ),
        )

        result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)

        assert result["metadata"] == {}


class TestAgentChatAppGenerateResponseConverterStream:
    def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]:
        def _gen():
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=1,
                stream_response=PingStreamResponse(task_id="t"),
            )
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=2,
                stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"),
            )
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=3,
                stream_response=MessageEndStreamResponse(
                    task_id="t",
                    id="m1",
                    metadata={
                        "retriever_resources": [
                            {
                                "segment_id": "s1",
                                "position": 1,
                                "document_name": "doc",
                                "score": 0.9,
                                "content": "content",
                                "summary": "summary",
                                "extra": "ignored",
                            }
                        ],
                        "annotation_reply": {"id": "a"},
                        "usage": {"prompt_tokens": 1},
                    },
                ),
            )
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=4,
                stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")),
            )

        return _gen()

    def test_convert_stream_full_response(self):
        items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream()))
        assert items[0] == "ping"
        assert items[1]["event"] == "message"
        assert "answer" in items[1]
        assert items[2]["event"] == "message_end"
        assert items[3]["event"] == "error"

    def test_convert_stream_simple_response(self):
        items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream()))
        assert items[0] == "ping"
        # Assert the message event structure and content at items[1]
        assert items[1]["event"] == "message"
        assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"]
        assert items[2]["event"] == "message_end"
        assert "metadata" in items[2]
        metadata = items[2]["metadata"]
        assert "annotation_reply" not in metadata
        assert "usage" not in metadata
        assert metadata["retriever_resources"] == [
            {
                "segment_id": "s1",
                "position": 1,
                "document_name": "doc",
                "score": 0.9,
                "content": "content",
                "summary": "summary",
            }
        ]
        assert items[3]["event"] == "error"
0
api/tests/unit_tests/core/app/apps/chat/__init__.py
Normal file
@@ -0,0 +1,113 @@
from types import SimpleNamespace
from unittest.mock import patch

from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from models.model import AppMode


class TestChatAppConfigManager:
    def test_get_app_config_uses_override_dict(self):
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
        app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"})
        override = {"model": "override"}

        model_entity = ModelConfigEntity(provider="p", model="m")
        prompt_entity = PromptTemplateEntity(
            prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
            simple_prompt_template="hi",
        )

        with (
            patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity),
            patch(
                "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
                return_value=None,
            ),
            patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None),
            patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])),
        ):
            app_config = ChatAppConfigManager.get_app_config(
                app_model=app_model,
                app_model_config=app_model_config,
                conversation=None,
                override_config_dict=override,
            )

        assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
        assert app_config.app_model_config_dict == override
        assert app_config.app_mode == AppMode.CHAT

    def test_config_validate_filters_related_keys(self):
        config = {"extra": 1}

        def _add_key(key, value):
            def _inner(*args, **kwargs):
                config = args[-1]
                config = {**config, key: value}
                return config, [key]

            return _inner

        with (
            patch(
                "core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
                side_effect=_add_key("model", 1),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
                side_effect=_add_key("inputs", 2),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
                side_effect=_add_key("file_upload", 3),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
                side_effect=_add_key("prompt", 4),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
                side_effect=_add_key("dataset", 5),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
                side_effect=_add_key("opening_statement", 6),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
                side_effect=_add_key("suggested_questions_after_answer", 7),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
                side_effect=_add_key("speech_to_text", 8),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
                side_effect=_add_key("text_to_speech", 9),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("retriever_resource", 10),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("sensitive_word_avoidance", 11),
            ),
        ):
            filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config)

        assert filtered["model"] == 1
        assert filtered["inputs"] == 2
        assert filtered["file_upload"] == 3
        assert filtered["prompt"] == 4
        assert filtered["dataset"] == 5
        assert filtered["opening_statement"] == 6
        assert filtered["suggested_questions_after_answer"] == 7
        assert filtered["speech_to_text"] == 8
        assert filtered["text_to_speech"] == 9
        assert filtered["retriever_resource"] == 10
        assert filtered["sensitive_word_avoidance"] == 11