add tool choices for agent. (#1126)
@@ -5,7 +5,7 @@ import os
 import pathlib
 import sys
 from datetime import datetime
-from typing import Union
+from typing import List, Optional, Union
 
 from fastapi.responses import StreamingResponse
 
@@ -40,7 +40,10 @@ logger.info(f"args: {args}")
 agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
 
 
-class AgentCompletionRequest(LLMParamsDoc):
+class AgentCompletionRequest(ChatCompletionRequest):
+    # rewrite, specify tools in this turn of conversation
+    tool_choice: Optional[List[str]] = None
+    # for short/long term in-memory
     thread_id: str = "0"
     user_id: str = "0"
 
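
With the new tool_choice field on AgentCompletionRequest, a caller can restrict which tools the agent may use for a single turn. A minimal client sketch (not part of the diff; the host, port, and route are assumptions about a locally deployed agent microservice):

import requests

# assumed endpoint of a locally running agent microservice
url = "http://localhost:9090/v1/chat/completions"

payload = {
    "messages": "What is OPEA?",
    "tool_choice": ["search_web"],  # only allow this tool for this turn
    "stream": False,
}
resp = requests.post(url, json=payload, proxies={"http": ""})
print(resp.json()["text"])  # non-streaming responses return a GeneratedDoc, so the answer is under "text"
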
@@ -52,42 +55,40 @@ class AgentCompletionRequest(LLMParamsDoc):
     host="0.0.0.0",
     port=args.port,
 )
-async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]):
+async def llm_generate(input: AgentCompletionRequest):
     if logflag:
         logger.info(input)
 
-    input.stream = args.stream
-    config = {"recursion_limit": args.recursion_limit}
+    # don't use global stream setting
+    # input.stream = args.stream
+    config = {"recursion_limit": args.recursion_limit, "tool_choice": input.tool_choice}
 
     if args.with_memory:
-        if isinstance(input, AgentCompletionRequest):
-            config["configurable"] = {"thread_id": input.thread_id}
-        else:
-            config["configurable"] = {"thread_id": "0"}
+        config["configurable"] = {"thread_id": input.thread_id}
 
     if logflag:
         logger.info(type(agent_inst))
 
-    if isinstance(input, LLMParamsDoc):
-        # use query as input
-        input_query = input.query
+    # openai compatible input
+    if isinstance(input.messages, str):
+        messages = input.messages
     else:
-        # openai compatible input
-        if isinstance(input.messages, str):
-            input_query = input.messages
-        else:
-            input_query = input.messages[-1]["content"]
+        # TODO: need handle multi-turn messages
+        messages = input.messages[-1]["content"]
 
     # 2. prepare the input for the agent
     if input.stream:
         logger.info("-----------STREAMING-------------")
-        return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")
+        return StreamingResponse(
+            agent_inst.stream_generator(messages, config),
+            media_type="text/event-stream",
+        )
 
     else:
         logger.info("-----------NOT STREAMING-------------")
-        response = await agent_inst.non_streaming_run(input_query, config)
+        response = await agent_inst.non_streaming_run(messages, config)
         logger.info("-----------Response-------------")
-        return GeneratedDoc(text=response, prompt=input_query)
+        return GeneratedDoc(text=response, prompt=messages)
 
 
 @register_microservice(
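
llm_generate now accepts only AgentCompletionRequest and reads the query from messages, which may be either a plain string or an OpenAI-style message list (multi-turn handling is still a TODO). A condensed illustration of that branch; the helper name is made up for the example:

def extract_query(messages):
    # mirrors the isinstance branch in llm_generate above
    if isinstance(messages, str):
        return messages
    # OpenAI-compatible input: fall back to the last message's content
    return messages[-1]["content"]

assert extract_query("What is OPEA?") == "What is OPEA?"
assert extract_query([{"role": "user", "content": "What is OPEA?"}]) == "What is OPEA?"
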
@@ -11,7 +11,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 
 from ...global_var import threads_global_kv
-from ...utils import has_multi_tool_inputs, tool_renderer
+from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
 from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
 
@@ -136,7 +136,8 @@ class ReActAgentwithLanggraph(BaseAgent):
 # does not rely on langchain bind_tools API
 # since tgi and vllm still do not have very good support for tool calling like OpenAI
 
-from typing import Annotated, Sequence, TypedDict
+import json
+from typing import Annotated, List, Optional, Sequence, TypedDict
 
 from langchain_core.messages import AIMessage, BaseMessage
 from langchain_core.prompts import PromptTemplate
@@ -154,6 +155,7 @@ class AgentState(TypedDict):
     """The state of the agent."""
 
     messages: Annotated[Sequence[BaseMessage], add_messages]
+    tool_choice: Optional[List[str]] = None
     is_last_step: IsLastStep
 
 
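
The new tool_choice entry in AgentState is filled per request rather than at agent construction time: as the planner changes below show, the value is popped from the runtime config into the initial graph state. A condensed illustration of that hand-off (values are made up):

config = {"recursion_limit": 5, "tool_choice": ["search_weather"]}
initial_state = {"messages": [("user", "How is the weather today?")]}

# move the per-turn tool restriction out of config and into the graph state,
# so langgraph only receives the config keys it expects
if "tool_choice" in config:
    initial_state["tool_choice"] = config.pop("tool_choice")

assert initial_state["tool_choice"] == ["search_weather"]
assert "tool_choice" not in config
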
@@ -191,7 +193,11 @@ class ReActAgentNodeLlama:
         history = assemble_history(messages)
         print("@@@ History: ", history)
 
-        tools_descriptions = tool_renderer(self.tools)
+        tools_used = self.tools
+        if state["tool_choice"] is not None:
+            tools_used = filter_tools(self.tools, state["tool_choice"])
+
+        tools_descriptions = tool_renderer(tools_used)
         print("@@@ Tools description: ", tools_descriptions)
 
         # invoke chain
@@ -279,21 +285,45 @@ class ReActAgentLlama(BaseAgent):
 
     async def stream_generator(self, query, config):
         initial_state = self.prepare_initial_state(query)
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
+
         try:
-            async for event in self.app.astream(initial_state, config=config):
-                for node_name, node_state in event.items():
-                    yield f"--- CALL {node_name} ---\n"
-                    for k, v in node_state.items():
-                        if v is not None:
-                            yield f"{k}: {v}\n"
+            async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
+                event_type = event[0]
+                data = event[1]
+                if event_type == "updates":
+                    for node_name, node_state in data.items():
+                        print(f"--- CALL {node_name} node ---\n")
+                        for k, v in node_state.items():
+                            if v is not None:
+                                print(f"------- {k}, {v} -------\n\n")
+                                if node_name == "agent":
+                                    if v[0].content == "":
+                                        tool_names = []
+                                        for tool_call in v[0].tool_calls:
+                                            tool_names.append(tool_call["name"])
+                                        result = {"tool": tool_names}
+                                    else:
+                                        result = {"content": [v[0].content.replace("\n\n", "\n")]}
+                                    # ui needs this format
+                                    yield f"data: {json.dumps(result)}\n\n"
+                                elif node_name == "tools":
+                                    full_content = v[0].content
+                                    tool_name = v[0].name
+                                    result = {"tool": tool_name, "content": [full_content]}
+                                    yield f"data: {json.dumps(result)}\n\n"
+                                    if not full_content:
+                                        continue
 
-                yield f"data: {repr(event)}\n\n"
             yield "data: [DONE]\n\n"
         except Exception as e:
             yield str(e)
 
     async def non_streaming_run(self, query, config):
         initial_state = self.prepare_initial_state(query)
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
         try:
             async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                 message = s["messages"][-1]
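
The rewritten stream_generator switches to langgraph's "updates" stream mode and emits UI-friendly SSE chunks: {"tool": [...]} when the agent plans tool calls, {"content": [...]} for model text, {"tool": name, "content": [...]} for tool output, and a final data: [DONE]. A sketch of a client consuming that stream (the endpoint is an assumption):

import json
import requests

url = "http://localhost:9090/v1/chat/completions"  # assumed endpoint
payload = {"messages": "How is the weather today?", "tool_choice": ["search_weather"], "stream": True}

with requests.post(url, json=payload, stream=True, proxies={"http": ""}) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        chunk = line[len("data: "):]
        if chunk == "[DONE]":
            break
        # e.g. {"tool": ["search_weather"]}, then {"tool": "search_weather", "content": ["..."]}, then {"content": ["..."]}
        print(json.loads(chunk))
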
@@ -86,6 +86,14 @@ def tool_renderer(tools):
     return "\n".join(tool_strings)
 
 
+def filter_tools(tools, tools_choices):
+    tool_used = []
+    for tool in tools:
+        if tool.name in tools_choices:
+            tool_used.append(tool)
+    return tool_used
+
+
 def has_multi_tool_inputs(tools):
     ret = False
     for tool in tools:
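
filter_tools keeps only the tools whose name appears in the request's tool_choice. A small standalone sketch (the tool class is a stand-in for any object with a name attribute, such as a LangChain tool, and the helper body is repeated here so the snippet runs on its own):

from dataclasses import dataclass

@dataclass
class FakeTool:
    name: str

def filter_tools(tools, tools_choices):
    # same logic as the helper added above
    return [t for t in tools if t.name in tools_choices]

tools = [FakeTool("search_web"), FakeTool("search_weather")]
assert [t.name for t in filter_tools(tools, ["search_weather"])] == ["search_weather"]
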
@@ -4,9 +4,17 @@
 
 # tool for unit test
 def search_web(query: str) -> str:
-    """Search the web for a given query."""
+    """Search the web knowledge for a given query."""
     ret_text = """
     The Linux Foundation AI & Data announced the Open Platform for Enterprise AI (OPEA) as its latest Sandbox Project.
     OPEA aims to accelerate secure, cost-effective generative AI (GenAI) deployments for businesses by driving interoperability across a diverse and heterogeneous ecosystem, starting with retrieval-augmented generation (RAG).
     """
     return ret_text
+
+
+def search_weather(query: str) -> str:
+    """Search the weather for a given query."""
+    ret_text = """
+    It's clear.
+    """
+    return ret_text
@@ -10,7 +10,7 @@ import requests
 def generate_answer_agent_api(url, prompt):
     proxies = {"http": ""}
     payload = {
-        "query": prompt,
+        "messages": prompt,
     }
     response = requests.post(url, json=payload, proxies=proxies)
     answer = response.json()["text"]
@@ -21,7 +21,7 @@ def process_request(url, query, is_stream=False):
     proxies = {"http": ""}
 
     payload = {
-        "query": query,
+        "messages": query,
     }
 
     try:
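
The test helpers now send the prompt under the messages key to match the new request schema. An example invocation against a locally running agent, assuming the helper above is importable and the URL matches the deployment:

agent_url = "http://localhost:9090/v1/chat/completions"  # assumed local endpoint
answer = generate_answer_agent_api(agent_url, "What is OPEA?")
print(answer)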