add tool choices for agent. (#1126)

lkk
2025-01-13 14:42:31 +08:00
committed by GitHub
parent fe24decd72
commit 3a7ccb0a75
5 changed files with 81 additions and 34 deletions

View File

@@ -5,7 +5,7 @@ import os
 import pathlib
 import sys
 from datetime import datetime
-from typing import Union
+from typing import List, Optional, Union
 from fastapi.responses import StreamingResponse
@@ -40,7 +40,10 @@ logger.info(f"args: {args}")
 agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
-class AgentCompletionRequest(LLMParamsDoc):
+class AgentCompletionRequest(ChatCompletionRequest):
+    # rewrite, specify tools in this turn of conversation
+    tool_choice: Optional[List[str]] = None
     # for short/long term in-memory
     thread_id: str = "0"
     user_id: str = "0"
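For reference, a request body under the new schema might look like the sketch below (values are illustrative; tool_choice names the tools to enable for this turn, and thread_id feeds the in-memory persistence):

# Hypothetical payload for the updated AgentCompletionRequest schema.
payload = {
    "messages": "What is OPEA?",
    "tool_choice": ["search_web"],  # restrict this turn to the search_web tool
    "thread_id": "0",  # used when the agent runs with memory enabled
    "stream": False,
}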
@@ -52,42 +55,40 @@
     host="0.0.0.0",
     port=args.port,
 )
-async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]):
+async def llm_generate(input: AgentCompletionRequest):
     if logflag:
         logger.info(input)
-    input.stream = args.stream
-    config = {"recursion_limit": args.recursion_limit}
+    # don't use global stream setting
+    # input.stream = args.stream
+    config = {"recursion_limit": args.recursion_limit, "tool_choice": input.tool_choice}
     if args.with_memory:
-        if isinstance(input, AgentCompletionRequest):
-            config["configurable"] = {"thread_id": input.thread_id}
-        else:
-            config["configurable"] = {"thread_id": "0"}
+        config["configurable"] = {"thread_id": input.thread_id}
     if logflag:
         logger.info(type(agent_inst))
-    if isinstance(input, LLMParamsDoc):
-        # use query as input
-        input_query = input.query
+    # openai compatible input
+    if isinstance(input.messages, str):
+        messages = input.messages
     else:
-        # openai compatible input
-        if isinstance(input.messages, str):
-            input_query = input.messages
-        else:
-            input_query = input.messages[-1]["content"]
+        # TODO: need to handle multi-turn messages
+        messages = input.messages[-1]["content"]
     # 2. prepare the input for the agent
     if input.stream:
         logger.info("-----------STREAMING-------------")
-        return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")
+        return StreamingResponse(
+            agent_inst.stream_generator(messages, config),
+            media_type="text/event-stream",
+        )
     else:
         logger.info("-----------NOT STREAMING-------------")
-        response = await agent_inst.non_streaming_run(input_query, config)
+        response = await agent_inst.non_streaming_run(messages, config)
         logger.info("-----------Response-------------")
-        return GeneratedDoc(text=response, prompt=input_query)
+        return GeneratedDoc(text=response, prompt=messages)
 @register_microservice(
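With the endpoint now typed to AgentCompletionRequest alone, a quick smoke test might look like the sketch below (host, port, and path are assumptions, not values confirmed by this diff):

import requests

url = "http://localhost:9090/v1/chat/completions"  # assumed endpoint
resp = requests.post(url, json={"messages": "What is OPEA?", "tool_choice": ["search_web"]})
print(resp.json()["text"])  # GeneratedDoc.text in the non-streaming case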

View File

@@ -11,7 +11,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 from ...global_var import threads_global_kv
-from ...utils import has_multi_tool_inputs, tool_renderer
+from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
 from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
@@ -136,7 +136,8 @@ class ReActAgentwithLanggraph(BaseAgent):
 # does not rely on langchain bind_tools API
 # since tgi and vllm still do not have very good support for tool calling like OpenAI
-from typing import Annotated, Sequence, TypedDict
+import json
+from typing import Annotated, List, Optional, Sequence, TypedDict
 from langchain_core.messages import AIMessage, BaseMessage
 from langchain_core.prompts import PromptTemplate
@@ -154,6 +155,7 @@ class AgentState(TypedDict):
     """The state of the agent."""
     messages: Annotated[Sequence[BaseMessage], add_messages]
+    tool_choice: Optional[List[str]] = None
     is_last_step: IsLastStep
@@ -191,7 +193,11 @@ class ReActAgentNodeLlama:
         history = assemble_history(messages)
         print("@@@ History: ", history)
-        tools_descriptions = tool_renderer(self.tools)
+        tools_used = self.tools
+        if state["tool_choice"] is not None:
+            tools_used = filter_tools(self.tools, state["tool_choice"])
+        tools_descriptions = tool_renderer(tools_used)
         print("@@@ Tools description: ", tools_descriptions)
         # invoke chain
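Taken together, tool_choice travels request field -> runnable config -> graph initial state -> per-turn filtering before the prompt is rendered. A condensed, self-contained sketch of that flow (filter_tools and tool_renderer come from this diff; everything else is illustrative):

from langchain_core.tools import tool

@tool
def search_weather(query: str) -> str:
    """Search the weather for a given query."""
    return "It's clear."

all_tools = [search_weather]  # assumed toolset for the sketch
config = {"recursion_limit": 10, "tool_choice": ["search_weather"]}
initial_state = {"messages": [("user", "Will it rain today?")]}
if "tool_choice" in config:
    initial_state["tool_choice"] = config.pop("tool_choice")

tools_used = all_tools
if initial_state.get("tool_choice") is not None:
    tools_used = filter_tools(all_tools, initial_state["tool_choice"])  # from utils
print(tool_renderer(tools_used))  # only the chosen tools reach the prompt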
@@ -279,21 +285,45 @@
     async def stream_generator(self, query, config):
         initial_state = self.prepare_initial_state(query)
-        try:
-            async for event in self.app.astream(initial_state, config=config):
-                for node_name, node_state in event.items():
-                    yield f"--- CALL {node_name} ---\n"
-                    for k, v in node_state.items():
-                        if v is not None:
-                            yield f"{k}: {v}\n"
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
+        try:
+            async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
+                event_type = event[0]
+                data = event[1]
+                if event_type == "updates":
+                    for node_name, node_state in data.items():
+                        print(f"--- CALL {node_name} node ---\n")
+                        for k, v in node_state.items():
+                            if v is not None:
+                                print(f"------- {k}, {v} -------\n\n")
+                                if node_name == "agent":
+                                    if v[0].content == "":
+                                        tool_names = []
+                                        for tool_call in v[0].tool_calls:
+                                            tool_names.append(tool_call["name"])
+                                        result = {"tool": tool_names}
+                                    else:
+                                        result = {"content": [v[0].content.replace("\n\n", "\n")]}
+                                    # ui needs this format
+                                    yield f"data: {json.dumps(result)}\n\n"
+                                elif node_name == "tools":
+                                    full_content = v[0].content
+                                    tool_name = v[0].name
+                                    result = {"tool": tool_name, "content": [full_content]}
+                                    yield f"data: {json.dumps(result)}\n\n"
+                                    if not full_content:
+                                        continue
-                yield f"data: {repr(event)}\n\n"
             yield "data: [DONE]\n\n"
         except Exception as e:
             yield str(e)

     async def non_streaming_run(self, query, config):
         initial_state = self.prepare_initial_state(query)
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
         try:
             async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                 message = s["messages"][-1]
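The reworked generator now emits UI-oriented SSE frames ({"tool": ...} or {"content": [...]}) terminated by data: [DONE]. A minimal consumer sketch (endpoint and payload are illustrative, not confirmed by this diff):

import json
import requests

url = "http://localhost:9090/v1/chat/completions"  # assumed endpoint
payload = {"messages": "What is OPEA?", "tool_choice": ["search_web"], "stream": True}
with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        print(json.loads(data))  # {"tool": ...} or {"content": [...]} frames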

View File

@@ -86,6 +86,14 @@ def tool_renderer(tools):
     return "\n".join(tool_strings)
+
+def filter_tools(tools, tools_choices):
+    tool_used = []
+    for tool in tools:
+        if tool.name in tools_choices:
+            tool_used.append(tool)
+    return tool_used
+
 def has_multi_tool_inputs(tools):
     ret = False
     for tool in tools:
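A quick sanity check of the new helper with stand-in tool objects (FakeTool is hypothetical; real entries are LangChain tools, and filter_tools only reads their .name attribute):

from dataclasses import dataclass

@dataclass
class FakeTool:
    # stand-in for a LangChain tool; only .name matters to the filter
    name: str

tools = [FakeTool("search_web"), FakeTool("search_weather")]
# assumes filter_tools defined above is in scope
assert [t.name for t in filter_tools(tools, ["search_weather"])] == ["search_weather"]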

View File

@@ -4,9 +4,17 @@
 # tool for unit test
 def search_web(query: str) -> str:
-    """Search the web for a given query."""
+    """Search the web knowledge for a given query."""
     ret_text = """
     The Linux Foundation AI & Data announced the Open Platform for Enterprise AI (OPEA) as its latest Sandbox Project.
     OPEA aims to accelerate secure, cost-effective generative AI (GenAI) deployments for businesses by driving interoperability across a diverse and heterogeneous ecosystem, starting with retrieval-augmented generation (RAG).
     """
     return ret_text
+
+
+def search_weather(query: str) -> str:
+    """Search the weather for a given query."""
+    ret_text = """
+    It's clear.
+    """
+    return ret_text
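In the unit tests these functions get wrapped as LangChain tools, which is what gives them the .name attribute that filter_tools matches on. A minimal sketch of that wrapping (using the tool decorator here is an assumption about the test harness, not something this diff shows):

from langchain_core.tools import tool

# Assumption: the harness wraps the plain function into a LangChain tool.
search_weather_tool = tool(search_weather)
print(search_weather_tool.name)  # "search_weather"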

View File

@@ -10,7 +10,7 @@ import requests
 def generate_answer_agent_api(url, prompt):
     proxies = {"http": ""}
     payload = {
-        "query": prompt,
+        "messages": prompt,
     }
     response = requests.post(url, json=payload, proxies=proxies)
     answer = response.json()["text"]
@@ -21,7 +21,7 @@ def process_request(url, query, is_stream=False):
     proxies = {"http": ""}
     payload = {
-        "query": query,
+        "messages": query,
     }
     try:
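Callers of these test helpers keep passing plain strings; only the payload key changed from "query" to "messages" to match the new request schema. A hedged usage sketch (the URL is illustrative, not a documented default):

# Hypothetical smoke test of the updated client helper.
url = "http://localhost:9090/v1/chat/completions"
print(generate_answer_agent_api(url, "What is OPEA?"))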