GenAIExamples/comps/agent/src/integrations/strategy/react/planner.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from langchain.agents import AgentExecutor
from langchain.agents import create_react_agent as create_react_langchain_agent
from langchain.memory import ChatMessageHistory
from langchain_core.messages import HumanMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface import ChatHuggingFace
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from ...global_var import threads_global_kv
from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
from ..base_agent import BaseAgent
from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
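

# ReActAgentwithLangchain wraps LangChain's classic ReAct agent (create_react_agent
# + AgentExecutor) with the hwchase17 ReAct prompt. It only supports single-input
# tools; when `with_memory` is set, per-session chat history is kept in an in-memory
# dict via RunnableWithMessageHistory.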
class ReActAgentwithLangchain(BaseAgent):
    def __init__(self, args, with_memory=False, **kwargs):
        super().__init__(args, local_vars=globals(), **kwargs)
        prompt = hwchase17_react_prompt
        if has_multi_tool_inputs(self.tools_descriptions):
            raise ValueError("Only supports single input tools when using strategy == react_langchain")
        else:
            agent_chain = create_react_langchain_agent(
                self.llm, self.tools_descriptions, prompt, tools_renderer=tool_renderer
            )
        self.app = AgentExecutor(
            agent=agent_chain, tools=self.tools_descriptions, verbose=True, handle_parsing_errors=True
        )
        self.memory = {}

        def get_session_history(session_id):
            if session_id in self.memory:
                return self.memory[session_id]
            else:
                mem = ChatMessageHistory()
                self.memory[session_id] = mem
                return mem

        if with_memory:
            self.app = RunnableWithMessageHistory(
                self.app,
                get_session_history,
                input_messages_key="input",
                history_messages_key="chat_history",
                history_factory_config=[],
            )

    def prepare_initial_state(self, query):
        return {"input": query}

    async def stream_generator(self, query, config, thread_id=None):
        initial_state = self.prepare_initial_state(query)
        if thread_id is not None:
            config["configurable"] = {"session_id": thread_id}
        async for chunk in self.app.astream(initial_state, config=config):
            if thread_id is not None:
                with threads_global_kv as g_threads:
                    thread_inst, created_at, status = g_threads[thread_id]
                    if status == "try_cancel":
                        yield "[thread_completion_callback] signal to cancel! Changed status to ready"
                        print("[thread_completion_callback] signal to cancel! Changed status to ready")
                        g_threads[thread_id] = (thread_inst, created_at, "ready")
                        break
            if "actions" in chunk:
                for action in chunk["actions"]:
                    yield f"Calling Tool: `{action.tool}` with input `{action.tool_input}`\n\n"
            # Observation
            elif "steps" in chunk:
                for step in chunk["steps"]:
                    yield f"Tool Result: `{step.observation}`\n\n"
            # Final result
            elif "output" in chunk:
                yield f"data: {repr(chunk['output'])}\n\n"
            else:
                raise ValueError(f"Unexpected chunk from AgentExecutor stream: {chunk}")
print("---")
yield "data: [DONE]\n\n"
class ReActAgentwithLanggraph(BaseAgent):
    def __init__(self, args, with_memory=False, **kwargs):
        super().__init__(args, local_vars=globals(), **kwargs)

        tools = self.tools_descriptions
        print("REACT_SYS_MESSAGE: ", REACT_SYS_MESSAGE)

        if with_memory:
            self.app = create_react_agent(
                self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE, checkpointer=MemorySaver()
            )
        else:
            self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE)

    def prepare_initial_state(self, query):
        return {"messages": [HumanMessage(content=query)]}

    async def stream_generator(self, query, config):
        initial_state = self.prepare_initial_state(query)
        try:
            async for event in self.app.astream(initial_state, config=config):
                for node_name, node_state in event.items():
                    yield f"--- CALL {node_name} ---\n"
                    for k, v in node_state.items():
                        if v is not None:
                            yield f"{k}: {v}\n"
                yield f"data: {repr(event)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield str(e)

    async def non_streaming_run(self, query, config):
        initial_state = self.prepare_initial_state(query)
        try:
            async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                message = s["messages"][-1]
                if isinstance(message, tuple):
                    print(message)
                else:
                    message.pretty_print()

            last_message = s["messages"][-1]
            print("******Response: ", last_message.content)
            return last_message.content
        except Exception as e:
            return str(e)


###############################################################################
# ReActAgentLlama:
# Only validated with the Llama3.1-70B-Instruct model served with TGI-gaudi.
# Supports multiple tools.
# Does not rely on LangChain's bind_tools API, since TGI and vLLM do not yet
# support tool calling as reliably as the OpenAI API.
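#
# Illustrative sketch (not part of the original file): ReActLlamaOutputParser is
# expected to return a list of JSON-like dicts, which the agent node below maps to
# either tool calls or a final answer, roughly:
#   [{"tool": "search_web", "args": {"query": "..."}}]  -> converted to AIMessage.tool_calls
#   [{"answer": "final response text"}]                 -> returned as the final AIMessage
# The exact keys and argument schema are defined by REACT_AGENT_LLAMA_PROMPT and
# .utils (convert_json_to_tool_call), not shown here; the tool name above is hypothetical.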
import json
from typing import Annotated, List, Optional, Sequence, TypedDict
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.prompts import PromptTemplate
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
from langgraph.prebuilt import ToolNode
from ...persistence import AgentPersistence, PersistenceConfig
from ...utils import setup_chat_model
from .utils import assemble_history, assemble_memory, convert_json_to_tool_call


class AgentState(TypedDict):
    """The state of the agent."""

    messages: Annotated[Sequence[BaseMessage], add_messages]
    tool_choice: Optional[List[str]] = None
    is_last_step: IsLastStep
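
# Notes on AgentState:
# - `messages` accumulates the conversation via langgraph's add_messages reducer.
# - `tool_choice` optionally restricts which tools the agent node may use for a
#   given request (see filter_tools below); the runners pop it from `config`.
# - `is_last_step` is langgraph's managed flag for the recursion limit.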


class ReActAgentNodeLlama:
    """Do planning and reasoning and generate tool calls.

    A workaround for open-source llm served by TGI-gaudi.
    """

    def __init__(self, tools, args):
        from .prompt import REACT_AGENT_LLAMA_PROMPT
        from .utils import ReActLlamaOutputParser

        output_parser = ReActLlamaOutputParser()
        prompt = PromptTemplate(
            template=REACT_AGENT_LLAMA_PROMPT,
            input_variables=["input", "history", "tools"],
        )
        llm = setup_chat_model(args)
        self.tools = tools
        self.chain = prompt | llm | output_parser
        self.with_memory = args.with_memory

    def __call__(self, state):
        print("---CALL Agent node---")
        messages = state["messages"]

        # assemble a prompt from messages
        if self.with_memory:
            query, history = assemble_memory(messages)
print("@@@ Query: ", history)
        else:
            query = messages[0].content
            history = assemble_history(messages)
            print("@@@ History: ", history)

        tools_used = self.tools
if state["tool_choice"] is not None:
tools_used = filter_tools(self.tools, state["tool_choice"])
        tools_descriptions = tool_renderer(tools_used)
        print("@@@ Tools description: ", tools_descriptions)

        # invoke chain
        output = self.chain.invoke({"input": query, "history": history, "tools": tools_descriptions})
        print("@@@ Output from chain: ", output)

        # convert output to tool calls
        tool_calls = []
        for res in output:
            if "tool" in res:
                add_kw_tc, tool_call = convert_json_to_tool_call(res)
                # print("Tool call:\n", tool_call)
                tool_calls.append(tool_call)

        if tool_calls:
            ai_message = AIMessage(content="", additional_kwargs=add_kw_tc, tool_calls=tool_calls)
        elif "answer" in output[0]:
            ai_message = AIMessage(content=str(output[0]["answer"]))
        else:
            ai_message = AIMessage(content=output)
        return {"messages": [ai_message]}
class ReActAgentLlama(BaseAgent):
    def __init__(self, args, with_memory=False, **kwargs):
        super().__init__(args, local_vars=globals(), **kwargs)

        agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args)
        tool_node = ToolNode(self.tools_descriptions)

        workflow = StateGraph(AgentState)

        # Define the nodes we will cycle between
        workflow.add_node("agent", agent)
        workflow.add_node("tools", tool_node)

        workflow.set_entry_point("agent")

        # We now add a conditional edge
        workflow.add_conditional_edges(
            # First, we define the start node. We use `agent`.
            # This means these are the edges taken after the `agent` node is called.
            "agent",
            # Next, we pass in the function that will determine which node is called next.
            self.should_continue,
            # Finally we pass in a mapping.
            # The keys are strings, and the values are other nodes.
            # END is a special node marking that the graph should finish.
            # What will happen is we will call `should_continue`, and then the output of that
            # will be matched against the keys in this mapping.
            # Based on which one it matches, that node will then be called.
            {
                # If `tools`, then we call the tool node.
                "continue": "tools",
                # Otherwise we finish.
                "end": END,
            },
        )

        # We now add a normal edge from `tools` to `agent`.
        # This means that after `tools` is called, `agent` node is called next.
        workflow.add_edge("tools", "agent")

        if args.with_memory:
            self.persistence = AgentPersistence(
                config=PersistenceConfig(checkpointer=args.with_memory, store=args.with_store)
            )
            print(self.persistence.checkpointer)
            self.app = workflow.compile(checkpointer=self.persistence.checkpointer, store=self.persistence.store)
        else:
            self.app = workflow.compile()

    # Define the function that determines whether to continue or not
    def should_continue(self, state: AgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not last_message.tool_calls:
            return "end"
        # Otherwise if there is, we continue
        else:
            return "continue"

    def prepare_initial_state(self, query):
        return {"messages": [HumanMessage(content=query)]}
    async def stream_generator(self, query, config):
        initial_state = self.prepare_initial_state(query)
        if "tool_choice" in config:
            initial_state["tool_choice"] = config.pop("tool_choice")

        try:
            async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
                event_type = event[0]
                data = event[1]
                if event_type == "updates":
                    for node_name, node_state in data.items():
                        print(f"--- CALL {node_name} node ---\n")
                        for k, v in node_state.items():
                            if v is not None:
                                print(f"------- {k}, {v} -------\n\n")
                                if node_name == "agent":
                                    if v[0].content == "":
                                        tool_names = []
                                        for tool_call in v[0].tool_calls:
                                            tool_names.append(tool_call["name"])
                                        result = {"tool": tool_names}
                                    else:
                                        result = {"content": [v[0].content.replace("\n\n", "\n")]}
                                    # ui needs this format
                                    yield f"data: {json.dumps(result)}\n\n"
                                elif node_name == "tools":
                                    full_content = v[0].content
                                    tool_name = v[0].name
                                    result = {"tool": tool_name, "content": [full_content]}
                                    yield f"data: {json.dumps(result)}\n\n"
                                    if not full_content:
                                        continue

            yield "data: [DONE]\n\n"
        except Exception as e:
            yield str(e)
    async def non_streaming_run(self, query, config):
        initial_state = self.prepare_initial_state(query)
        if "tool_choice" in config:
            initial_state["tool_choice"] = config.pop("tool_choice")

        try:
            async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                message = s["messages"][-1]
                if isinstance(message, tuple):
                    print(message)
                else:
                    message.pretty_print()

            last_message = s["messages"][-1]
            print("******Response: ", last_message.content)
            return last_message.content
        except Exception as e:
            return str(e)
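

# --- Illustrative usage sketch (assumptions, not part of the original file) ---
# The `args` fields below are placeholders; the real agent config is parsed elsewhere
# in the comps/agent service, so treat every attribute name here as hypothetical.
#
#   import asyncio
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(model="meta-llama/Llama-3.1-70B-Instruct", with_memory=False)
#   agent = ReActAgentLlama(args)  # also requires tool configuration on args
#   answer = asyncio.run(agent.non_streaming_run("What is OPEA?", config={}))
#   print(answer)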