# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from langchain.agents import AgentExecutor
from langchain.agents import create_react_agent as create_react_langchain_agent
from langchain.memory import ChatMessageHistory
from langchain_core.messages import HumanMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface import ChatHuggingFace
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

from ...global_var import threads_global_kv
from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
from ..base_agent import BaseAgent
from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt

class ReActAgentwithLangchain(BaseAgent):
    def __init__(self, args, with_memory=False, **kwargs):
        super().__init__(args, local_vars=globals(), **kwargs)
        prompt = hwchase17_react_prompt
        if has_multi_tool_inputs(self.tools_descriptions):
            raise ValueError("Only supports single input tools when using strategy == react_langchain")
        else:
            agent_chain = create_react_langchain_agent(
                self.llm, self.tools_descriptions, prompt, tools_renderer=tool_renderer
            )
        self.app = AgentExecutor(
            agent=agent_chain, tools=self.tools_descriptions, verbose=True, handle_parsing_errors=True
        )
        self.memory = {}

        def get_session_history(session_id):
            # One ChatMessageHistory per session_id, created lazily on first use.
            if session_id in self.memory:
                return self.memory[session_id]
            else:
                mem = ChatMessageHistory()
                self.memory[session_id] = mem
                return mem

        if with_memory:
            self.app = RunnableWithMessageHistory(
                self.app,
                get_session_history,
                input_messages_key="input",
                history_messages_key="chat_history",
                history_factory_config=[],
            )
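        # When memory is enabled, history is looked up per session: callers select
        # a session by passing config={"configurable": {"session_id": ...}}, which
        # stream_generator below derives from its thread_id argument.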
    def prepare_initial_state(self, query):
        return {"input": query}

    async def stream_generator(self, query, config, thread_id=None):
        initial_state = self.prepare_initial_state(query)
        if thread_id is not None:
            config["configurable"] = {"session_id": thread_id}
        async for chunk in self.app.astream(initial_state, config=config):
            if thread_id is not None:
                with threads_global_kv as g_threads:
                    thread_inst, created_at, status = g_threads[thread_id]
                    if status == "try_cancel":
                        yield "[thread_completion_callback] signal to cancel! Changed status to ready"
                        print("[thread_completion_callback] signal to cancel! Changed status to ready")
                        g_threads[thread_id] = (thread_inst, created_at, "ready")
                        break
            # Agent action
            if "actions" in chunk:
                for action in chunk["actions"]:
                    yield f"Calling Tool: `{action.tool}` with input `{action.tool_input}`\n\n"
            # Observation
            elif "steps" in chunk:
                for step in chunk["steps"]:
                    yield f"Tool Result: `{step.observation}`\n\n"
            # Final result
            elif "output" in chunk:
                yield f"data: {repr(chunk['output'])}\n\n"
            else:
                raise ValueError(f"Unexpected chunk from AgentExecutor stream: {chunk}")
            print("---")
        yield "data: [DONE]\n\n"
class ReActAgentwithLanggraph(BaseAgent):
    def __init__(self, args, with_memory=False, **kwargs):
        super().__init__(args, local_vars=globals(), **kwargs)

        tools = self.tools_descriptions
        print("REACT_SYS_MESSAGE: ", REACT_SYS_MESSAGE)

        if with_memory:
            self.app = create_react_agent(
                self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE, checkpointer=MemorySaver()
            )
        else:
            self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE)

    def prepare_initial_state(self, query):
        return {"messages": [HumanMessage(content=query)]}

    async def stream_generator(self, query, config):
        initial_state = self.prepare_initial_state(query)
        try:
            async for event in self.app.astream(initial_state, config=config):
                for node_name, node_state in event.items():
                    yield f"--- CALL {node_name} ---\n"
                    for k, v in node_state.items():
                        if v is not None:
                            yield f"{k}: {v}\n"

                yield f"data: {repr(event)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield str(e)

    async def non_streaming_run(self, query, config):
        initial_state = self.prepare_initial_state(query)
        try:
            async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                message = s["messages"][-1]
                if isinstance(message, tuple):
                    print(message)
                else:
                    message.pretty_print()

            last_message = s["messages"][-1]
            print("******Response: ", last_message.content)
            return last_message.content
        except Exception as e:
            return str(e)
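

# Note: when ReActAgentwithLanggraph is built with with_memory=True, the
# MemorySaver checkpointer scopes state by thread, so each run should pass
# config={"configurable": {"thread_id": ...}} to resume the same conversation
# (standard langgraph usage).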
###############################################################################
# ReActAgentLlama:
# Only validated with the Llama3.1-70B-Instruct model served with TGI-gaudi.
# Supports multiple tools.
# Does not rely on the langchain bind_tools API, since TGI and vLLM do not yet
# support OpenAI-style tool calling well.
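#
# Inferred from the parsing logic in ReActAgentNodeLlama below: the output
# parser yields a list of dicts, each carrying either a "tool" key (a tool
# call to execute) or an "answer" key (the final response). The exact JSON
# schema the model is asked to emit is defined in REACT_AGENT_LLAMA_PROMPT.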

import json
from typing import Annotated, List, Optional, Sequence, TypedDict

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.prompts import PromptTemplate
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
from langgraph.prebuilt import ToolNode

from ...persistence import AgentPersistence, PersistenceConfig
from ...utils import setup_chat_model
from .utils import assemble_history, assemble_memory, convert_json_to_tool_call

class AgentState(TypedDict):
    """The state of the agent."""

    messages: Annotated[Sequence[BaseMessage], add_messages]
    tool_choice: Optional[List[str]]  # TypedDict fields cannot take defaults; populated from the request config
    is_last_step: IsLastStep
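

# Note: add_messages is a langgraph reducer; each node's returned "messages"
# are appended to the running history rather than replacing it.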
class ReActAgentNodeLlama:
    """Do planning and reasoning and generate tool calls.

    A workaround for open-source LLMs served by TGI-gaudi.
    """

    def __init__(self, tools, args):
        from .prompt import REACT_AGENT_LLAMA_PROMPT
        from .utils import ReActLlamaOutputParser

        output_parser = ReActLlamaOutputParser()
        prompt = PromptTemplate(
            template=REACT_AGENT_LLAMA_PROMPT,
            input_variables=["input", "history", "tools"],
        )
        llm = setup_chat_model(args)
        self.tools = tools
        self.chain = prompt | llm | output_parser
        self.with_memory = args.with_memory

    def __call__(self, state):
        print("---CALL Agent node---")
        messages = state["messages"]

        # assemble a prompt from messages
        if self.with_memory:
            query, history = assemble_memory(messages)
            print("@@@ Query: ", query)
            print("@@@ History: ", history)
        else:
            query = messages[0].content
            history = assemble_history(messages)
            print("@@@ History: ", history)

        tools_used = self.tools
        # Use .get() since tool_choice may be absent when not supplied by the request.
        if state.get("tool_choice") is not None:
            tools_used = filter_tools(self.tools, state["tool_choice"])

        tools_descriptions = tool_renderer(tools_used)
        print("@@@ Tools description: ", tools_descriptions)

        # invoke chain
        output = self.chain.invoke({"input": query, "history": history, "tools": tools_descriptions})
        print("@@@ Output from chain: ", output)

        # convert output to tool calls
        tool_calls = []
        add_kw_tc = {}
        for res in output:
            if "tool" in res:
                add_kw_tc, tool_call = convert_json_to_tool_call(res)
                # print("Tool call:\n", tool_call)
                tool_calls.append(tool_call)

        if tool_calls:
            ai_message = AIMessage(content="", additional_kwargs=add_kw_tc, tool_calls=tool_calls)
        elif "answer" in output[0]:
            ai_message = AIMessage(content=str(output[0]["answer"]))
        else:
            ai_message = AIMessage(content=output)
        return {"messages": [ai_message]}
class ReActAgentLlama(BaseAgent):
    def __init__(self, args, with_memory=False, **kwargs):
        super().__init__(args, local_vars=globals(), **kwargs)
        agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args)
        tool_node = ToolNode(self.tools_descriptions)

        workflow = StateGraph(AgentState)

        # Define the nodes we will cycle between
        workflow.add_node("agent", agent)
        workflow.add_node("tools", tool_node)

        workflow.set_entry_point("agent")

        # We now add a conditional edge
        workflow.add_conditional_edges(
            # First, we define the start node. We use `agent`.
            # This means these are the edges taken after the `agent` node is called.
            "agent",
            # Next, we pass in the function that will determine which node is called next.
            self.should_continue,
            # Finally we pass in a mapping.
            # The keys are strings, and the values are other nodes.
            # END is a special node marking that the graph should finish.
            # What will happen is we will call `should_continue`, and the output of that
            # will be matched against the keys in this mapping.
            # Based on which one it matches, that node will then be called.
            {
                # If `continue`, then we call the tool node.
                "continue": "tools",
                # Otherwise we finish.
                "end": END,
            },
        )

        # We now add a normal edge from `tools` to `agent`.
        # This means that after `tools` is called, the `agent` node is called next.
        workflow.add_edge("tools", "agent")

        if args.with_memory:
            self.persistence = AgentPersistence(
                config=PersistenceConfig(checkpointer=args.with_memory, store=args.with_store)
            )
            print(self.persistence.checkpointer)
            self.app = workflow.compile(checkpointer=self.persistence.checkpointer, store=self.persistence.store)
        else:
            self.app = workflow.compile()

    # Define the function that determines whether to continue or not
    def should_continue(self, state: AgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not last_message.tool_calls:
            return "end"
        # Otherwise if there is, we continue
        else:
            return "continue"

    def prepare_initial_state(self, query):
        return {"messages": [HumanMessage(content=query)]}

    async def stream_generator(self, query, config):
        initial_state = self.prepare_initial_state(query)
        if "tool_choice" in config:
            initial_state["tool_choice"] = config.pop("tool_choice")

        try:
            async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
                event_type = event[0]
                data = event[1]
                if event_type == "updates":
                    for node_name, node_state in data.items():
                        print(f"--- CALL {node_name} node ---\n")
                        for k, v in node_state.items():
                            if v is not None:
                                print(f"------- {k}, {v} -------\n\n")
                                if node_name == "agent":
                                    if v[0].content == "":
                                        tool_names = []
                                        for tool_call in v[0].tool_calls:
                                            tool_names.append(tool_call["name"])
                                        result = {"tool": tool_names}
                                    else:
                                        result = {"content": [v[0].content.replace("\n\n", "\n")]}
                                    # ui needs this format
                                    yield f"data: {json.dumps(result)}\n\n"
                                elif node_name == "tools":
                                    full_content = v[0].content
                                    tool_name = v[0].name
                                    result = {"tool": tool_name, "content": [full_content]}
                                    yield f"data: {json.dumps(result)}\n\n"
                                    if not full_content:
                                        continue

            yield "data: [DONE]\n\n"
        except Exception as e:
            yield str(e)

    async def non_streaming_run(self, query, config):
        initial_state = self.prepare_initial_state(query)
        if "tool_choice" in config:
            initial_state["tool_choice"] = config.pop("tool_choice")
        try:
            async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                message = s["messages"][-1]
                if isinstance(message, tuple):
                    print(message)
                else:
                    message.pretty_print()

            last_message = s["messages"][-1]
            print("******Response: ", last_message.content)
            return last_message.content
        except Exception as e:
            return str(e)
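

# A hypothetical end-to-end driver (a sketch; assumes an `args` namespace
# accepted by BaseAgent and setup_chat_model, not a documented interface):
#
#     agent = ReActAgentLlama(args)
#     config = {"recursion_limit": 50}  # standard langgraph run config
#     answer = await agent.non_streaming_run("What is OPEA?", config)
#     print(answer)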