mirror of
https://github.com/langgenius/dify.git
synced 2026-02-24 18:05:11 +00:00
feat(api): implement tool argument validation in Function Call strategy
This commit is contained in:
@@ -451,6 +451,37 @@ class AgentPattern(ABC):
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
|
||||
"""Validate tool arguments against the tool's required parameters.
|
||||
|
||||
Checks that all required LLM-facing parameters are present and non-empty
|
||||
before actual execution, preventing wasted tool invocations when the model
|
||||
generates calls with missing arguments (e.g. empty ``{}``).
|
||||
|
||||
Returns:
|
||||
Error message if validation fails, None if all required parameters are satisfied.
|
||||
"""
|
||||
prompt_tool = tool_instance.to_prompt_message_tool()
|
||||
required_params: list[str] = prompt_tool.parameters.get("required", [])
|
||||
|
||||
if not required_params:
|
||||
return None
|
||||
|
||||
missing = [
|
||||
p for p in required_params
|
||||
if p not in tool_args
|
||||
or tool_args[p] is None
|
||||
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
|
||||
]
|
||||
|
||||
if not missing:
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Missing required parameter(s): {', '.join(missing)}. "
|
||||
f"Please provide all required parameters before calling this tool."
|
||||
)
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
"""Function Call strategy implementation."""
|
||||
"""Function Call strategy implementation.
|
||||
|
||||
Implements the Function Call agent pattern where the LLM uses native tool-calling
|
||||
capability to invoke tools. Includes pre-execution parameter validation that
|
||||
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
|
||||
and avoids counting purely-invalid rounds against the iteration budget.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
@@ -20,6 +27,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
@@ -41,6 +50,11 @@ class FunctionCallStrategy(AgentPattern):
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
# Consecutive rounds where ALL tool calls failed parameter validation.
|
||||
# When this happens the round is "free" (iteration_step not incremented)
|
||||
# up to a safety cap to prevent infinite loops.
|
||||
consecutive_validation_failures: int = 0
|
||||
max_validation_retries: int = 3
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
@@ -100,16 +114,21 @@ class FunctionCallStrategy(AgentPattern):
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
all_validation_errors: bool = True
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
# Execute tools (with pre-execution parameter validation)
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
if not is_validation_error:
|
||||
all_validation_errors = False
|
||||
else:
|
||||
all_validation_errors = False
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
@@ -123,7 +142,27 @@ class FunctionCallStrategy(AgentPattern):
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Skip iteration counter when every tool call in this round failed validation,
|
||||
# giving the model a free retry — but cap retries to prevent infinite loops.
|
||||
if tool_calls and all_validation_errors:
|
||||
consecutive_validation_failures += 1
|
||||
if consecutive_validation_failures >= max_validation_retries:
|
||||
logger.warning(
|
||||
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
|
||||
consecutive_validation_failures,
|
||||
)
|
||||
iteration_step += 1
|
||||
consecutive_validation_failures = 0
|
||||
else:
|
||||
logger.info(
|
||||
"All tool calls failed validation (attempt %d/%d), not counting iteration",
|
||||
consecutive_validation_failures,
|
||||
max_validation_retries,
|
||||
)
|
||||
else:
|
||||
consecutive_validation_failures = 0
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
@@ -225,8 +264,18 @@ class FunctionCallStrategy(AgentPattern):
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
|
||||
"""Handle a single tool call and return response with files and meta."""
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
|
||||
"""Handle a single tool call and return response with files, meta, and validation status.
|
||||
|
||||
Validates required parameters before execution. When validation fails the tool
|
||||
is never invoked — a synthetic error is fed back to the model so it can self-correct
|
||||
without consuming a real iteration.
|
||||
|
||||
Returns:
|
||||
(response_content, tool_files, tool_invoke_meta, is_validation_error).
|
||||
``is_validation_error`` is True when the call was rejected due to missing
|
||||
required parameters, allowing the caller to skip the iteration counter.
|
||||
"""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
@@ -250,6 +299,19 @@ class FunctionCallStrategy(AgentPattern):
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Validate required parameters before execution to avoid wasted invocations
|
||||
validation_error = self._validate_tool_args(tool_instance, tool_args)
|
||||
if validation_error:
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = validation_error
|
||||
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
|
||||
yield tool_call_log
|
||||
|
||||
messages.append(
|
||||
ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name)
|
||||
)
|
||||
return validation_error, [], None, True
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
@@ -272,7 +334,7 @@ class FunctionCallStrategy(AgentPattern):
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta
|
||||
return response_content, tool_files, tool_invoke_meta, False
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
@@ -293,4 +355,4 @@ class FunctionCallStrategy(AgentPattern):
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None
|
||||
return error_content, [], None, False
|
||||
|
||||
Reference in New Issue
Block a user