add tool choices for agent. (#1126)
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
||||
import pathlib
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
@@ -40,7 +40,10 @@ logger.info(f"args: {args}")
|
||||
agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
|
||||
|
||||
|
||||
class AgentCompletionRequest(LLMParamsDoc):
|
||||
class AgentCompletionRequest(ChatCompletionRequest):
|
||||
# rewrite, specify tools in this turn of conversation
|
||||
tool_choice: Optional[List[str]] = None
|
||||
# for short/long term in-memory
|
||||
thread_id: str = "0"
|
||||
user_id: str = "0"
|
||||
|
||||
@@ -52,42 +55,40 @@ class AgentCompletionRequest(LLMParamsDoc):
|
||||
host="0.0.0.0",
|
||||
port=args.port,
|
||||
)
|
||||
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]):
|
||||
async def llm_generate(input: AgentCompletionRequest):
|
||||
if logflag:
|
||||
logger.info(input)
|
||||
|
||||
input.stream = args.stream
|
||||
config = {"recursion_limit": args.recursion_limit}
|
||||
# don't use global stream setting
|
||||
# input.stream = args.stream
|
||||
config = {"recursion_limit": args.recursion_limit, "tool_choice": input.tool_choice}
|
||||
|
||||
if args.with_memory:
|
||||
if isinstance(input, AgentCompletionRequest):
|
||||
config["configurable"] = {"thread_id": input.thread_id}
|
||||
else:
|
||||
config["configurable"] = {"thread_id": "0"}
|
||||
config["configurable"] = {"thread_id": input.thread_id}
|
||||
|
||||
if logflag:
|
||||
logger.info(type(agent_inst))
|
||||
|
||||
if isinstance(input, LLMParamsDoc):
|
||||
# use query as input
|
||||
input_query = input.query
|
||||
# openai compatible input
|
||||
if isinstance(input.messages, str):
|
||||
messages = input.messages
|
||||
else:
|
||||
# openai compatible input
|
||||
if isinstance(input.messages, str):
|
||||
input_query = input.messages
|
||||
else:
|
||||
input_query = input.messages[-1]["content"]
|
||||
# TODO: need handle multi-turn messages
|
||||
messages = input.messages[-1]["content"]
|
||||
|
||||
# 2. prepare the input for the agent
|
||||
if input.stream:
|
||||
logger.info("-----------STREAMING-------------")
|
||||
return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")
|
||||
return StreamingResponse(
|
||||
agent_inst.stream_generator(messages, config),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
else:
|
||||
logger.info("-----------NOT STREAMING-------------")
|
||||
response = await agent_inst.non_streaming_run(input_query, config)
|
||||
response = await agent_inst.non_streaming_run(messages, config)
|
||||
logger.info("-----------Response-------------")
|
||||
return GeneratedDoc(text=response, prompt=input_query)
|
||||
return GeneratedDoc(text=response, prompt=messages)
|
||||
|
||||
|
||||
@register_microservice(
|
||||
|
||||
@@ -11,7 +11,7 @@ from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from ...global_var import threads_global_kv
|
||||
from ...utils import has_multi_tool_inputs, tool_renderer
|
||||
from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
|
||||
from ..base_agent import BaseAgent
|
||||
from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
|
||||
|
||||
@@ -136,7 +136,8 @@ class ReActAgentwithLanggraph(BaseAgent):
|
||||
# does not rely on langchain bind_tools API
|
||||
# since tgi and vllm still do not have very good support for tool calling like OpenAI
|
||||
|
||||
from typing import Annotated, Sequence, TypedDict
|
||||
import json
|
||||
from typing import Annotated, List, Optional, Sequence, TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
@@ -154,6 +155,7 @@ class AgentState(TypedDict):
|
||||
"""The state of the agent."""
|
||||
|
||||
messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
tool_choice: Optional[List[str]] = None
|
||||
is_last_step: IsLastStep
|
||||
|
||||
|
||||
@@ -191,7 +193,11 @@ class ReActAgentNodeLlama:
|
||||
history = assemble_history(messages)
|
||||
print("@@@ History: ", history)
|
||||
|
||||
tools_descriptions = tool_renderer(self.tools)
|
||||
tools_used = self.tools
|
||||
if state["tool_choice"] is not None:
|
||||
tools_used = filter_tools(self.tools, state["tool_choice"])
|
||||
|
||||
tools_descriptions = tool_renderer(tools_used)
|
||||
print("@@@ Tools description: ", tools_descriptions)
|
||||
|
||||
# invoke chain
|
||||
@@ -279,21 +285,45 @@ class ReActAgentLlama(BaseAgent):
|
||||
|
||||
async def stream_generator(self, query, config):
|
||||
initial_state = self.prepare_initial_state(query)
|
||||
try:
|
||||
async for event in self.app.astream(initial_state, config=config):
|
||||
for node_name, node_state in event.items():
|
||||
yield f"--- CALL {node_name} ---\n"
|
||||
for k, v in node_state.items():
|
||||
if v is not None:
|
||||
yield f"{k}: {v}\n"
|
||||
if "tool_choice" in config:
|
||||
initial_state["tool_choice"] = config.pop("tool_choice")
|
||||
|
||||
try:
|
||||
async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
|
||||
event_type = event[0]
|
||||
data = event[1]
|
||||
if event_type == "updates":
|
||||
for node_name, node_state in data.items():
|
||||
print(f"--- CALL {node_name} node ---\n")
|
||||
for k, v in node_state.items():
|
||||
if v is not None:
|
||||
print(f"------- {k}, {v} -------\n\n")
|
||||
if node_name == "agent":
|
||||
if v[0].content == "":
|
||||
tool_names = []
|
||||
for tool_call in v[0].tool_calls:
|
||||
tool_names.append(tool_call["name"])
|
||||
result = {"tool": tool_names}
|
||||
else:
|
||||
result = {"content": [v[0].content.replace("\n\n", "\n")]}
|
||||
# ui needs this format
|
||||
yield f"data: {json.dumps(result)}\n\n"
|
||||
elif node_name == "tools":
|
||||
full_content = v[0].content
|
||||
tool_name = v[0].name
|
||||
result = {"tool": tool_name, "content": [full_content]}
|
||||
yield f"data: {json.dumps(result)}\n\n"
|
||||
if not full_content:
|
||||
continue
|
||||
|
||||
yield f"data: {repr(event)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
yield str(e)
|
||||
|
||||
async def non_streaming_run(self, query, config):
|
||||
initial_state = self.prepare_initial_state(query)
|
||||
if "tool_choice" in config:
|
||||
initial_state["tool_choice"] = config.pop("tool_choice")
|
||||
try:
|
||||
async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
|
||||
message = s["messages"][-1]
|
||||
|
||||
@@ -86,6 +86,14 @@ def tool_renderer(tools):
|
||||
return "\n".join(tool_strings)
|
||||
|
||||
|
||||
def filter_tools(tools, tools_choices):
|
||||
tool_used = []
|
||||
for tool in tools:
|
||||
if tool.name in tools_choices:
|
||||
tool_used.append(tool)
|
||||
return tool_used
|
||||
|
||||
|
||||
def has_multi_tool_inputs(tools):
|
||||
ret = False
|
||||
for tool in tools:
|
||||
|
||||
@@ -4,9 +4,17 @@
|
||||
|
||||
# tool for unit test
|
||||
def search_web(query: str) -> str:
|
||||
"""Search the web for a given query."""
|
||||
"""Search the web knowledge for a given query."""
|
||||
ret_text = """
|
||||
The Linux Foundation AI & Data announced the Open Platform for Enterprise AI (OPEA) as its latest Sandbox Project.
|
||||
OPEA aims to accelerate secure, cost-effective generative AI (GenAI) deployments for businesses by driving interoperability across a diverse and heterogeneous ecosystem, starting with retrieval-augmented generation (RAG).
|
||||
"""
|
||||
return ret_text
|
||||
|
||||
|
||||
def search_weather(query: str) -> str:
|
||||
"""Search the weather for a given query."""
|
||||
ret_text = """
|
||||
It's clear.
|
||||
"""
|
||||
return ret_text
|
||||
|
||||
@@ -10,7 +10,7 @@ import requests
|
||||
def generate_answer_agent_api(url, prompt):
|
||||
proxies = {"http": ""}
|
||||
payload = {
|
||||
"query": prompt,
|
||||
"messages": prompt,
|
||||
}
|
||||
response = requests.post(url, json=payload, proxies=proxies)
|
||||
answer = response.json()["text"]
|
||||
@@ -21,7 +21,7 @@ def process_request(url, query, is_stream=False):
|
||||
proxies = {"http": ""}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"messages": query,
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user