fix history content from agent memory. (#899)
* fix history content from agent memory. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -147,6 +147,7 @@ from langgraph.prebuilt import ToolNode
|
||||
|
||||
from ...persistence import AgentPersistence, PersistenceConfig
|
||||
from ...utils import setup_chat_model
|
||||
from .utils import assemble_history, assemble_memory, convert_json_to_tool_call
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
@@ -174,16 +175,20 @@ class ReActAgentNodeLlama:
|
||||
llm = setup_chat_model(args)
|
||||
self.tools = tools
|
||||
self.chain = prompt | llm | output_parser
|
||||
self.with_memory = args.with_memory
|
||||
|
||||
def __call__(self, state):
|
||||
from .utils import assemble_history, convert_json_to_tool_call
|
||||
|
||||
print("---CALL Agent node---")
|
||||
messages = state["messages"]
|
||||
|
||||
# assemble a prompt from messages
|
||||
query = messages[0].content
|
||||
history = assemble_history(messages)
|
||||
if self.with_memory:
|
||||
query, history = assemble_memory(messages)
|
||||
print("@@@ Query: ", history)
|
||||
else:
|
||||
query = messages[0].content
|
||||
history = assemble_history(messages)
|
||||
print("@@@ History: ", history)
|
||||
|
||||
tools_descriptions = tool_renderer(self.tools)
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import uuid
|
||||
|
||||
from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
|
||||
@@ -82,3 +82,41 @@ def assemble_history(messages):
|
||||
query_history += f"Assistant Output: {m.content}\n"
|
||||
|
||||
return query_history
|
||||
|
||||
|
||||
def assemble_memory(messages):
|
||||
"""
|
||||
messages: Human, AI, TOOL, AI, TOOL, etc.
|
||||
"""
|
||||
query = ""
|
||||
query_id = None
|
||||
query_history = ""
|
||||
breaker = "-" * 10
|
||||
|
||||
# get query
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m, HumanMessage):
|
||||
query = m.content
|
||||
query_id = m.id
|
||||
break
|
||||
|
||||
for m in messages:
|
||||
if isinstance(m, AIMessage):
|
||||
# if there is tool call
|
||||
if hasattr(m, "tool_calls") and len(m.tool_calls) > 0:
|
||||
for tool_call in m.tool_calls:
|
||||
tool = tool_call["name"]
|
||||
tc_args = tool_call["args"]
|
||||
id = tool_call["id"]
|
||||
tool_output = get_tool_output(messages, id)
|
||||
query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n"
|
||||
else:
|
||||
# did not make tool calls
|
||||
query_history += f"Assistant Output: {m.content}\n"
|
||||
|
||||
elif isinstance(m, HumanMessage):
|
||||
if m.id == query_id:
|
||||
continue
|
||||
query_history += f"Human Input: {m.content}\n"
|
||||
|
||||
return query, query_history
|
||||
|
||||
Reference in New Issue
Block a user