fix history content from agent memory. (#899)

* fix history content from agent memory.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
lkk
2024-11-14 21:26:01 +08:00
committed by GitHub
parent 0dbf57751b
commit 32bcde4528
2 changed files with 47 additions and 4 deletions

View File

@@ -147,6 +147,7 @@ from langgraph.prebuilt import ToolNode
from ...persistence import AgentPersistence, PersistenceConfig
from ...utils import setup_chat_model
from .utils import assemble_history, assemble_memory, convert_json_to_tool_call
class AgentState(TypedDict):
@@ -174,16 +175,20 @@ class ReActAgentNodeLlama:
llm = setup_chat_model(args)
self.tools = tools
self.chain = prompt | llm | output_parser
self.with_memory = args.with_memory
def __call__(self, state):
from .utils import assemble_history, convert_json_to_tool_call
print("---CALL Agent node---")
messages = state["messages"]
# assemble a prompt from messages
query = messages[0].content
history = assemble_history(messages)
if self.with_memory:
query, history = assemble_memory(messages)
print("@@@ Query: ", history)
else:
query = messages[0].content
history = assemble_history(messages)
print("@@@ History: ", history)
tools_descriptions = tool_renderer(self.tools)

View File

@@ -5,7 +5,7 @@ import json
import uuid
from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import BaseOutputParser
@@ -82,3 +82,41 @@ def assemble_history(messages):
query_history += f"Assistant Output: {m.content}\n"
return query_history
def assemble_memory(messages):
"""
messages: Human, AI, TOOL, AI, TOOL, etc.
"""
query = ""
query_id = None
query_history = ""
breaker = "-" * 10
# get query
for m in messages[::-1]:
if isinstance(m, HumanMessage):
query = m.content
query_id = m.id
break
for m in messages:
if isinstance(m, AIMessage):
# if there is tool call
if hasattr(m, "tool_calls") and len(m.tool_calls) > 0:
for tool_call in m.tool_calls:
tool = tool_call["name"]
tc_args = tool_call["args"]
id = tool_call["id"]
tool_output = get_tool_output(messages, id)
query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n"
else:
# did not make tool calls
query_history += f"Assistant Output: {m.content}\n"
elif isinstance(m, HumanMessage):
if m.id == query_id:
continue
query_history += f"Human Input: {m.content}\n"
return query, query_history