add tool choices for agent. (#1126)

Author: lkk
Date: 2025-01-13 14:42:31 +08:00
Committed by: GitHub
Parent: fe24decd72
Commit: 3a7ccb0a75
5 changed files with 81 additions and 34 deletions


@@ -5,7 +5,7 @@ import os
 import pathlib
 import sys
 from datetime import datetime
-from typing import Union
+from typing import List, Optional, Union

 from fastapi.responses import StreamingResponse
@@ -40,7 +40,10 @@ logger.info(f"args: {args}")
 agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)


-class AgentCompletionRequest(LLMParamsDoc):
+class AgentCompletionRequest(ChatCompletionRequest):
+    # rewrite, specify tools in this turn of conversation
+    tool_choice: Optional[List[str]] = None
+    # for short/long term in-memory
     thread_id: str = "0"
     user_id: str = "0"
@@ -52,42 +55,40 @@ class AgentCompletionRequest(LLMParamsDoc):
     host="0.0.0.0",
     port=args.port,
 )
-async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]):
+async def llm_generate(input: AgentCompletionRequest):
     if logflag:
         logger.info(input)
-    input.stream = args.stream
-    config = {"recursion_limit": args.recursion_limit}
+    # don't use global stream setting
+    # input.stream = args.stream
+    config = {"recursion_limit": args.recursion_limit, "tool_choice": input.tool_choice}
     if args.with_memory:
-        if isinstance(input, AgentCompletionRequest):
-            config["configurable"] = {"thread_id": input.thread_id}
-        else:
-            config["configurable"] = {"thread_id": "0"}
+        config["configurable"] = {"thread_id": input.thread_id}

     if logflag:
         logger.info(type(agent_inst))
-    if isinstance(input, LLMParamsDoc):
-        # use query as input
-        input_query = input.query
-    else:
-        # openai compatible input
-        if isinstance(input.messages, str):
-            input_query = input.messages
-        else:
-            input_query = input.messages[-1]["content"]
+    # openai compatible input
+    if isinstance(input.messages, str):
+        messages = input.messages
+    else:
+        # TODO: need handle multi-turn messages
+        messages = input.messages[-1]["content"]

     # 2. prepare the input for the agent
     if input.stream:
         logger.info("-----------STREAMING-------------")
-        return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")
+        return StreamingResponse(
+            agent_inst.stream_generator(messages, config),
+            media_type="text/event-stream",
+        )
     else:
         logger.info("-----------NOT STREAMING-------------")
-        response = await agent_inst.non_streaming_run(input_query, config)
+        response = await agent_inst.non_streaming_run(messages, config)
         logger.info("-----------Response-------------")
-        return GeneratedDoc(text=response, prompt=input_query)
+        return GeneratedDoc(text=response, prompt=messages)


 @register_microservice(
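With this change the agent endpoint accepts an OpenAI-style chat payload plus an optional tool_choice list naming the tools allowed for this turn. The request sketch below is not part of the commit; it only illustrates the new schema, and the host, port, and endpoint path are assumptions to be replaced with the values of your deployment.

import requests

# Hypothetical endpoint; use the address your agent microservice is registered on.
url = "http://localhost:9090/v1/chat/completions"
payload = {
    "messages": "What is OPEA?",    # string or OpenAI-style message list
    "tool_choice": ["search_web"],  # restrict this turn to the named tools
    "stream": False,
    "thread_id": "0",               # used when the agent runs with memory
}
response = requests.post(url, json=payload, proxies={"http": ""})
print(response.json()["text"])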


@@ -11,7 +11,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent

 from ...global_var import threads_global_kv
-from ...utils import has_multi_tool_inputs, tool_renderer
+from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
 from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
@@ -136,7 +136,8 @@ class ReActAgentwithLanggraph(BaseAgent):
 # does not rely on langchain bind_tools API
 # since tgi and vllm still do not have very good support for tool calling like OpenAI
-from typing import Annotated, Sequence, TypedDict
+import json
+from typing import Annotated, List, Optional, Sequence, TypedDict

 from langchain_core.messages import AIMessage, BaseMessage
 from langchain_core.prompts import PromptTemplate
@@ -154,6 +155,7 @@ class AgentState(TypedDict):
     """The state of the agent."""

     messages: Annotated[Sequence[BaseMessage], add_messages]
+    tool_choice: Optional[List[str]] = None
     is_last_step: IsLastStep
@@ -191,7 +193,11 @@ class ReActAgentNodeLlama:
         history = assemble_history(messages)
         print("@@@ History: ", history)

-        tools_descriptions = tool_renderer(self.tools)
+        tools_used = self.tools
+        if state["tool_choice"] is not None:
+            tools_used = filter_tools(self.tools, state["tool_choice"])
+
+        tools_descriptions = tool_renderer(tools_used)
         print("@@@ Tools description: ", tools_descriptions)

         # invoke chain
@@ -279,21 +285,45 @@ class ReActAgentLlama(BaseAgent):
     async def stream_generator(self, query, config):
         initial_state = self.prepare_initial_state(query)
-        try:
-            async for event in self.app.astream(initial_state, config=config):
-                for node_name, node_state in event.items():
-                    yield f"--- CALL {node_name} ---\n"
-                    for k, v in node_state.items():
-                        if v is not None:
-                            yield f"{k}: {v}\n"
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
+
+        try:
+            async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
+                event_type = event[0]
+                data = event[1]
+                if event_type == "updates":
+                    for node_name, node_state in data.items():
+                        print(f"--- CALL {node_name} node ---\n")
+                        for k, v in node_state.items():
+                            if v is not None:
+                                print(f"------- {k}, {v} -------\n\n")
+                        if node_name == "agent":
+                            if v[0].content == "":
+                                tool_names = []
+                                for tool_call in v[0].tool_calls:
+                                    tool_names.append(tool_call["name"])
+                                result = {"tool": tool_names}
+                            else:
+                                result = {"content": [v[0].content.replace("\n\n", "\n")]}
+                            # ui needs this format
+                            yield f"data: {json.dumps(result)}\n\n"
+                        elif node_name == "tools":
+                            full_content = v[0].content
+                            tool_name = v[0].name
+                            result = {"tool": tool_name, "content": [full_content]}
+                            yield f"data: {json.dumps(result)}\n\n"
+                            if not full_content:
+                                continue

-                yield f"data: {repr(event)}\n\n"
             yield "data: [DONE]\n\n"
         except Exception as e:
             yield str(e)

     async def non_streaming_run(self, query, config):
         initial_state = self.prepare_initial_state(query)
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
         try:
             async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                 message = s["messages"][-1]
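The rewritten stream_generator emits UI-oriented server-sent events: an agent step yields data: {"tool": [...]} when the model issues tool calls or data: {"content": [...]} when it produces text, a tool step yields data: {"tool": <name>, "content": [...]}, and the stream ends with data: [DONE]. A minimal consumer sketch, assuming the same hypothetical endpoint as above (not part of the commit):

import json
import requests

url = "http://localhost:9090/v1/chat/completions"  # assumed endpoint
payload = {"messages": "How is the weather?", "tool_choice": ["search_weather"], "stream": True}

with requests.post(url, json=payload, stream=True, proxies={"http": ""}) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        body = line[len("data: "):]
        if body == "[DONE]":
            break
        event = json.loads(body)
        if "tool" in event:
            print("tool event:", event)  # tool-call names or tool output
        else:
            print("content:", event["content"])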


@@ -86,6 +86,14 @@ def tool_renderer(tools):
     return "\n".join(tool_strings)


+def filter_tools(tools, tools_choices):
+    tool_used = []
+    for tool in tools:
+        if tool.name in tools_choices:
+            tool_used.append(tool)
+    return tool_used
+
+
 def has_multi_tool_inputs(tools):
     ret = False
     for tool in tools:
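filter_tools keeps only the tools whose .name appears in the request's tool_choice list before tool_renderer builds the prompt. A standalone sketch of the behaviour follows; it repeats the helper's logic so it runs on its own, and the example tools are hypothetical, decorated with langchain_core's @tool only so each callable gets a .name attribute.

from langchain_core.tools import tool


def filter_tools(tools, tools_choices):
    # same logic as the helper added in this commit
    tool_used = []
    for t in tools:
        if t.name in tools_choices:
            tool_used.append(t)
    return tool_used


@tool
def search_web(query: str) -> str:
    """Search the web for a given query."""
    return "web result"


@tool
def search_weather(query: str) -> str:
    """Search the weather for a given query."""
    return "It's clear."


selected = filter_tools([search_web, search_weather], ["search_weather"])
print([t.name for t in selected])  # ['search_weather']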


@@ -4,9 +4,17 @@
 # tool for unit test
 def search_web(query: str) -> str:
-    """Search the web for a given query."""
+    """Search the web knowledge for a given query."""
     ret_text = """
     The Linux Foundation AI & Data announced the Open Platform for Enterprise AI (OPEA) as its latest Sandbox Project.
     OPEA aims to accelerate secure, cost-effective generative AI (GenAI) deployments for businesses by driving interoperability across a diverse and heterogeneous ecosystem, starting with retrieval-augmented generation (RAG).
     """
     return ret_text
+
+
+def search_weather(query: str) -> str:
+    """Search the weather for a given query."""
+    ret_text = """
+    It's clear.
+    """
+    return ret_text


@@ -10,7 +10,7 @@ import requests
 def generate_answer_agent_api(url, prompt):
     proxies = {"http": ""}
     payload = {
-        "query": prompt,
+        "messages": prompt,
     }
     response = requests.post(url, json=payload, proxies=proxies)
     answer = response.json()["text"]
@@ -21,7 +21,7 @@ def process_request(url, query, is_stream=False):
     proxies = {"http": ""}
     payload = {
-        "query": query,
+        "messages": query,
     }
     try: