From 292bfceae61dbd2fcab0ce3c5009389537622b7c Mon Sep 17 00:00:00 2001 From: Novice Date: Sat, 14 Feb 2026 14:28:34 +0800 Subject: [PATCH] feat(api): implement tool argument validation in Function Call strategy --- api/core/agent/patterns/base.py | 31 +++++++++ api/core/agent/patterns/function_call.py | 80 +++++++++++++++++++++--- 2 files changed, 102 insertions(+), 9 deletions(-) diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py index 03c349a475..4e6fa5284d 100644 --- a/api/core/agent/patterns/base.py +++ b/api/core/agent/patterns/base.py @@ -451,6 +451,37 @@ class AgentPattern(ABC): return response_content, tool_files, None + def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None: + """Validate tool arguments against the tool's required parameters. + + Checks that all required LLM-facing parameters are present and non-empty + before actual execution, preventing wasted tool invocations when the model + generates calls with missing arguments (e.g. empty ``{}``). + + Returns: + Error message if validation fails, None if all required parameters are satisfied. + """ + prompt_tool = tool_instance.to_prompt_message_tool() + required_params: list[str] = prompt_tool.parameters.get("required", []) + + if not required_params: + return None + + missing = [ + p for p in required_params + if p not in tool_args + or tool_args[p] is None + or (isinstance(tool_args[p], str) and not tool_args[p].strip()) + ] + + if not missing: + return None + + return ( + f"Missing required parameter(s): {', '.join(missing)}. " + f"Please provide all required parameters before calling this tool." + ) + def _find_tool_by_name(self, tool_name: str) -> Tool | None: """Find a tool instance by its name.""" for tool in self.tools: diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py index 9c87c186d8..d9f043bbe5 100644 --- a/api/core/agent/patterns/function_call.py +++ b/api/core/agent/patterns/function_call.py @@ -1,6 +1,13 @@ -"""Function Call strategy implementation.""" +"""Function Call strategy implementation. + +Implements the Function Call agent pattern where the LLM uses native tool-calling +capability to invoke tools. Includes pre-execution parameter validation that +intercepts invalid calls (e.g. empty arguments) before they reach tool backends, +and avoids counting purely-invalid rounds against the iteration budget. +""" import json +import logging from collections.abc import Generator from typing import Any, Union @@ -20,6 +27,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta from .base import AgentPattern +logger = logging.getLogger(__name__) + class FunctionCallStrategy(AgentPattern): """Function Call strategy using model's native tool calling capability.""" @@ -41,6 +50,11 @@ class FunctionCallStrategy(AgentPattern): final_text: str = "" finish_reason: str | None = None output_files: list[File] = [] # Track files produced by tools + # Consecutive rounds where ALL tool calls failed parameter validation. + # When this happens the round is "free" (iteration_step not incremented) + # up to a safety cap to prevent infinite loops. + consecutive_validation_failures: int = 0 + max_validation_retries: int = 3 while function_call_state and iteration_step <= max_iterations: function_call_state = False @@ -100,16 +114,21 @@ class FunctionCallStrategy(AgentPattern): # Process tool calls tool_outputs: dict[str, str] = {} + all_validation_errors: bool = True if tool_calls: function_call_state = True - # Execute tools + # Execute tools (with pre-execution parameter validation) for tool_call_id, tool_name, tool_args in tool_calls: - tool_response, tool_files, _ = yield from self._handle_tool_call( + tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call( tool_name, tool_args, tool_call_id, messages, round_log ) tool_outputs[tool_name] = tool_response - # Track files produced by tools output_files.extend(tool_files) + if not is_validation_error: + all_validation_errors = False + else: + all_validation_errors = False + yield self._finish_log( round_log, data={ @@ -123,7 +142,27 @@ class FunctionCallStrategy(AgentPattern): }, usage=round_usage.get("usage"), ) - iteration_step += 1 + + # Skip iteration counter when every tool call in this round failed validation, + # giving the model a free retry — but cap retries to prevent infinite loops. + if tool_calls and all_validation_errors: + consecutive_validation_failures += 1 + if consecutive_validation_failures >= max_validation_retries: + logger.warning( + "Agent hit %d consecutive validation-only rounds, forcing iteration increment", + consecutive_validation_failures, + ) + iteration_step += 1 + consecutive_validation_failures = 0 + else: + logger.info( + "All tool calls failed validation (attempt %d/%d), not counting iteration", + consecutive_validation_failures, + max_validation_retries, + ) + else: + consecutive_validation_failures = 0 + iteration_step += 1 # Return final result from core.agent.entities import AgentResult @@ -225,8 +264,18 @@ class FunctionCallStrategy(AgentPattern): tool_call_id: str, messages: list[PromptMessage], round_log: AgentLog, - ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]: - """Handle a single tool call and return response with files and meta.""" + ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]: + """Handle a single tool call and return response with files, meta, and validation status. + + Validates required parameters before execution. When validation fails the tool + is never invoked — a synthetic error is fed back to the model so it can self-correct + without consuming a real iteration. + + Returns: + (response_content, tool_files, tool_invoke_meta, is_validation_error). + ``is_validation_error`` is True when the call was rejected due to missing + required parameters, allowing the caller to skip the iteration counter. + """ # Find tool tool_instance = self._find_tool_by_name(tool_name) if not tool_instance: @@ -250,6 +299,19 @@ class FunctionCallStrategy(AgentPattern): ) yield tool_call_log + # Validate required parameters before execution to avoid wasted invocations + validation_error = self._validate_tool_args(tool_instance, tool_args) + if validation_error: + tool_call_log.status = AgentLog.LogStatus.ERROR + tool_call_log.error = validation_error + tool_call_log.data = {**tool_call_log.data, "error": validation_error} + yield tool_call_log + + messages.append( + ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name) + ) + return validation_error, [], None, True + # Invoke tool using base class method with error handling try: response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name) @@ -272,7 +334,7 @@ class FunctionCallStrategy(AgentPattern): name=tool_name, ) ) - return response_content, tool_files, tool_invoke_meta + return response_content, tool_files, tool_invoke_meta, False except Exception as e: # Tool invocation failed, yield error log error_message = str(e) @@ -293,4 +355,4 @@ class FunctionCallStrategy(AgentPattern): name=tool_name, ) ) - return error_content, [], None + return error_content, [], None, False