feat(api): implement tool argument validation in Function Call strategy

This commit is contained in:
Novice
2026-02-14 14:28:34 +08:00
parent 4b6f1861ef
commit 292bfceae6
2 changed files with 102 additions and 9 deletions

View File

@@ -451,6 +451,37 @@ class AgentPattern(ABC):
return response_content, tool_files, None
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
"""Validate tool arguments against the tool's required parameters.
Checks that all required LLM-facing parameters are present and non-empty
before actual execution, preventing wasted tool invocations when the model
generates calls with missing arguments (e.g. empty ``{}``).
Returns:
Error message if validation fails, None if all required parameters are satisfied.
"""
prompt_tool = tool_instance.to_prompt_message_tool()
required_params: list[str] = prompt_tool.parameters.get("required", [])
if not required_params:
return None
missing = [
p for p in required_params
if p not in tool_args
or tool_args[p] is None
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
]
if not missing:
return None
return (
f"Missing required parameter(s): {', '.join(missing)}. "
f"Please provide all required parameters before calling this tool."
)
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
"""Find a tool instance by its name."""
for tool in self.tools:

View File

@@ -1,6 +1,13 @@
"""Function Call strategy implementation."""
"""Function Call strategy implementation.
Implements the Function Call agent pattern where the LLM uses native tool-calling
capability to invoke tools. Includes pre-execution parameter validation that
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
and avoids counting purely-invalid rounds against the iteration budget.
"""
import json
import logging
from collections.abc import Generator
from typing import Any, Union
@@ -20,6 +27,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
from .base import AgentPattern
logger = logging.getLogger(__name__)
class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
@@ -41,6 +50,11 @@ class FunctionCallStrategy(AgentPattern):
final_text: str = ""
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
# Consecutive rounds where ALL tool calls failed parameter validation.
# When this happens the round is "free" (iteration_step not incremented)
# up to a safety cap to prevent infinite loops.
consecutive_validation_failures: int = 0
max_validation_retries: int = 3
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
@@ -100,16 +114,21 @@ class FunctionCallStrategy(AgentPattern):
# Process tool calls
tool_outputs: dict[str, str] = {}
all_validation_errors: bool = True
if tool_calls:
function_call_state = True
# Execute tools
# Execute tools (with pre-execution parameter validation)
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if not is_validation_error:
all_validation_errors = False
else:
all_validation_errors = False
yield self._finish_log(
round_log,
data={
@@ -123,7 +142,27 @@ class FunctionCallStrategy(AgentPattern):
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Skip iteration counter when every tool call in this round failed validation,
# giving the model a free retry — but cap retries to prevent infinite loops.
if tool_calls and all_validation_errors:
consecutive_validation_failures += 1
if consecutive_validation_failures >= max_validation_retries:
logger.warning(
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
consecutive_validation_failures,
)
iteration_step += 1
consecutive_validation_failures = 0
else:
logger.info(
"All tool calls failed validation (attempt %d/%d), not counting iteration",
consecutive_validation_failures,
max_validation_retries,
)
else:
consecutive_validation_failures = 0
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
@@ -225,8 +264,18 @@ class FunctionCallStrategy(AgentPattern):
tool_call_id: str,
messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
"""Handle a single tool call and return response with files and meta."""
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
"""Handle a single tool call and return response with files, meta, and validation status.
Validates required parameters before execution. When validation fails the tool
is never invoked — a synthetic error is fed back to the model so it can self-correct
without consuming a real iteration.
Returns:
(response_content, tool_files, tool_invoke_meta, is_validation_error).
``is_validation_error`` is True when the call was rejected due to missing
required parameters, allowing the caller to skip the iteration counter.
"""
# Find tool
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
@@ -250,6 +299,19 @@ class FunctionCallStrategy(AgentPattern):
)
yield tool_call_log
# Validate required parameters before execution to avoid wasted invocations
validation_error = self._validate_tool_args(tool_instance, tool_args)
if validation_error:
tool_call_log.status = AgentLog.LogStatus.ERROR
tool_call_log.error = validation_error
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
yield tool_call_log
messages.append(
ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name)
)
return validation_error, [], None, True
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
@@ -272,7 +334,7 @@ class FunctionCallStrategy(AgentPattern):
name=tool_name,
)
)
return response_content, tool_files, tool_invoke_meta
return response_content, tool_files, tool_invoke_meta, False
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
@@ -293,4 +355,4 @@ class FunctionCallStrategy(AgentPattern):
name=tool_name,
)
)
return error_content, [], None
return error_content, [], None, False