This reverts commit c36c5032dc.
Co-authored-by: lkk <33276950+lkk12014402@users.noreply.github.com>
This commit is contained in:
@@ -11,11 +11,10 @@ We currently support the following types of agents:
|
||||
1. ReAct: use `react_langchain` or `react_langgraph` or `react_llama` as strategy. First introduced in this seminal [paper](https://arxiv.org/abs/2210.03629). The ReAct agent engages in "reason-act-observe" cycles to solve problems. Please refer to this [doc](https://python.langchain.com/v0.2/docs/how_to/migrate_agent/) to understand the differences between the langchain and langgraph versions of react agents. See table below to understand the validated LLMs for each react strategy.
|
||||
2. RAG agent: use `rag_agent` or `rag_agent_llama` strategy. This agent is specifically designed for improving RAG performance. It has the capability to rephrase query, check relevancy of retrieved context, and iterate if context is not relevant. See table below to understand the validated LLMs for each rag agent strategy.
|
||||
3. Plan and execute: `plan_execute` strategy. This type of agent first makes a step-by-step plan given a user request, and then execute the plan sequentially (or in parallel, to be implemented in future). If the execution results can solve the problem, then the agent will output an answer; otherwise, it will replan and execute again.
|
||||
4. SQL agent: use `sql_agent_llama` or `sql_agent` strategy. This agent is specifically designed and optimized for answering questions aabout data in SQL databases. For more technical details read descriptions [here](src/strategy/sqlagent/README.md).
|
||||
|
||||
**Note**:
|
||||
|
||||
1. Due to the limitations in support for tool calling by TGI and vllm, we have developed subcategories of agent strategies (`rag_agent_llama`, `react_llama` and `sql_agent_llama`) specifically designed for open-source LLMs served with TGI and vllm.
|
||||
1. Due to the limitations in support for tool calling by TGI and vllm, we have developed subcategories of agent strategies (`rag_agent_llama` and `react_llama`) specifically designed for open-source LLMs served with TGI and vllm.
|
||||
2. For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below.
|
||||
|
||||
### 1.2 LLM engine
|
||||
@@ -26,16 +25,14 @@ Agents use LLM for reasoning and planning. We support 3 options of LLM engine:
|
||||
2. Open-source LLMs served with vllm. Follow the instructions in [Section 2.2.2](#222-start-agent-microservices-with-vllm).
|
||||
3. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=<your-openai-key>`
|
||||
|
||||
| Agent type | `strategy` arg | Validated LLMs (serving SW) | Notes |
|
||||
| ---------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ReAct | `react_langchain` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Only allows tools with one input variable |
|
||||
| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), | if using vllm, need to specify `--enable-auto-tool-choice --tool-call-parser ${model_parser}`, refer to vllm docs for more info |
|
||||
| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) (vllm-gaudi) | Recommended for open-source LLMs |
|
||||
| RAG agent | `rag_agent` | GPT-4o-mini | |
|
||||
| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) (vllm-gaudi) | Recommended for open-source LLMs, only allows 1 tool with input variable to be "query" |
|
||||
| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) (vllm-gaudi) | Currently, due to some issues with guaided decoding of vllm-gaudi, this strategy does not work properly with vllm-gaudi. We are actively debugging. Stay tuned. In the meanwhile, you can use OpenAI's models with this strategy. |
|
||||
| SQL agent | `sql_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (vllm-gaudi) | database query tool is natively integrated using Langchain's [QuerySQLDataBaseTool](https://python.langchain.com/api_reference/community/tools/langchain_community.tools.sql_database.tool.QuerySQLDataBaseTool.html#langchain_community.tools.sql_database.tool.QuerySQLDataBaseTool). User can also register their own tools with this agent. |
|
||||
| SQL agent | `sql_agent` | GPT-4o-mini | database query tool is natively integrated using Langchain's [QuerySQLDataBaseTool](https://python.langchain.com/api_reference/community/tools/langchain_community.tools.sql_database.tool.QuerySQLDataBaseTool.html#langchain_community.tools.sql_database.tool.QuerySQLDataBaseTool). User can also register their own tools with this agent. |
|
||||
| Agent type | `strategy` arg | Validated LLMs (serving SW) | Notes |
|
||||
| ---------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ReAct | `react_langchain` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Only allows tools with one input variable |
|
||||
| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), | if using vllm, need to specify `--enable-auto-tool-choice --tool-call-parser ${model_parser}`, refer to vllm docs for more info |
|
||||
| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs |
|
||||
| RAG agent | `rag_agent` | GPT-4o-mini | |
|
||||
| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs, only allows 1 tool with input variable to be "query" |
|
||||
| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) (vllm-gaudi) | |
|
||||
|
||||
### 1.3 Tools
|
||||
|
||||
@@ -123,12 +120,12 @@ Once microservice starts, user can use below script to invoke.
|
||||
|
||||
```bash
|
||||
curl http://${ip_address}:9090/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{
|
||||
"query": "What is OPEA project?"
|
||||
"query": "What is the weather today in Austin?"
|
||||
}'
|
||||
|
||||
# expected output
|
||||
|
||||
data: 'The OPEA project is .....</s>' # just showing partial example here.
|
||||
data: 'The temperature in Austin today is 78°F.</s>'
|
||||
|
||||
data: [DONE]
|
||||
|
||||
@@ -213,4 +210,4 @@ data: [DONE]
|
||||
## 5. Customize agent strategy
|
||||
|
||||
For advanced developers who want to implement their own agent strategies, you can add a separate folder in `src\strategy`, implement your agent by inherit the `BaseAgent` class, and add your strategy into the `src\agent.py`. The architecture of this agent microservice is shown in the diagram below as a reference.
|
||||

|
||||

|
||||
|
||||
|
Before Width: | Height: | Size: 740 KiB After Width: | Height: | Size: 740 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 246 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 284 KiB |
@@ -1,46 +0,0 @@
|
||||
[HumanMessage(content="what's the most recent album from the founder of ysl records?", id='cfde4aba-0464-4ad9-bd1c-d3fc40bbb46e'), AIMessage(content='', addi
|
||||
tional_kwargs={'tool_calls': [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'query': 'founder of YSL Records'}, name
|
||||
e='duckduckgo_search', description=None), id='142bd66e-7ec6-4381-bcb6-2d0bd4fcecd3', type='function')]}, id='6430421d-2238-452b-9fbe-4d6cc8bc0cd8', tool_call
|
||||
s=[{'name': 'duckduckgo_search', 'args': {'query': 'founder of YSL Records'}, 'id': '142bd66e-7ec6-4381-bcb6-2d0bd4fcecd3', 'type': 'tool_call'}]), ToolMessa
|
||||
ge(content='Prosecutors allege the chart-topping artist\'s label Young Stoner Life Records also stands for Young Slime ... a YSL co-founder who reached a ple
|
||||
a deal in 2022 and claimed Young Thug was a part ... In 2016, Young Thug founded YSL Records, with its full name, "Young Stoner Life," perfectly representing
|
||||
the enigmatic persona of the label. YSL Records quickly became known for its ... The longest criminal trial in the State of Georgia\'s history has been plag
|
||||
ued by endless problems, including the recusal of not one but two judges in charge of the case. Young Thug and his lawyer ... Prosecutors say YSL - the acron
|
||||
ym for the artist\'s label, Young Stoner Life Records - also stands for Young Slime Life, an Atlanta-based street gang affiliated with the national Bloods ga
|
||||
ng. YSL Records calls its roster of artists the "Slime Family." ... YSL co-founder Walter Murphy entered a guilty plea on a single count of conspiracy to vio
|
||||
late the state\'s Racketeer Influenced ...', name='duckduckgo_search', id='1d9ebf6b-bf62-469a-8d45-365c0fa15bba', tool_call_id='142bd66e-7ec6-4381-bcb6-2d0bd
|
||||
4fcecd3'), HumanMessage(content='Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again.'
|
||||
, id='6f2fb9db-e5b4-4483-b69a-06122c996eb9'), AIMessage(content='', additional_kwargs={'tool_calls': [ChatCompletionOutputToolCall(function=ChatCompletionOut
|
||||
putFunctionDefinition(arguments={'query': 'YSL Records founder'}, name='duckduckgo_search', description=None), id='6d073578-5bf1-449e-8d32-4f3b5f999a06', typ
|
||||
e='function')]}, id='50a644af-3117-4513-be56-5fb4b16d0481', tool_calls=[{'name': 'duckduckgo_search', 'args': {'query': 'YSL Records founder'}, 'id': '6d0735
|
||||
78-5bf1-449e-8d32-4f3b5f999a06', 'type': 'tool_call'}]), ToolMessage(content='That same year, he founded the YSL record label, which the rapper has used to p
|
||||
ropel close friends and family members to industry success. The Birth Of YSL Records And Its Impact On The Music Industry In 2016, Young Thug founded YSL Rec
|
||||
ords, with its full name, "Young Stoner Life," perfectly representing the enigmatic persona of ... Young Thug, who runs the Young Stoner Life label, has been
|
||||
accused of co-founding the Young Slime Life Atlanta gang and violating the RICO act, among other charges. Here\'s what to know about the ... The rapper foun
|
||||
dead the record label Young Stoner Life in 2016 as an imprint of 300 Entertainment. YSL Records calls its roster of artists the "Slime Family." One of the YSL
|
||||
Records founder\'s charges includes conspiracy to violate the Racketeer Influenced and Corrupt Organizations Act (RICO).', name='duckduckgo_search', id='afc
|
||||
cee2f-f36f-4d86-8abd-8d05d733649d', tool_call_id='6d073578-5bf1-449e-8d32-4f3b5f999a06'), HumanMessage(content='Retrieved document is not sufficient or relev
|
||||
ant to answer the query. Reformulate the query to search knowledge base again.', id='22af0af8-92d0-4876-ad4e-16457c37f1f3'), AIMessage(content='', additional
|
||||
_kwargs={'tool_calls': [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'query': 'latest album by Young Thug'}, name=
|
||||
'duckduckgo_search', description=None), id='2315fb35-11ec-4a6b-be86-405a1efef411', type='function')]}, id='189ba925-9f82-4e9e-bdab-6dcfc49a196e', tool_calls=
|
||||
[{'name': 'duckduckgo_search', 'args': {'query': 'YSL Records founder name'}, 'id': 'cddeca43-88e2-4279-90aa-4c0b58b8b3c4', 'type': 'tool_call'}, {'name': 'd
|
||||
uckduckgo_search', 'args': {'query': 'latest album by Young Thug'}, 'id': '2315fb35-11ec-4a6b-be86-405a1efef411', 'type': 'tool_call'}]), ToolMessage(content
|
||||
='In 2016, Young Thug founded YSL Records, with its full name, "Young Stoner Life," perfectly representing the enigmatic persona of the label. YSL Records qu
|
||||
ickly became known for its ... Thug is also the founder of Young Stoner Life Records. However, the rapper was arrested in May 2022 and charged with conspirin
|
||||
g to violate the state\'s Racketeer Influenced and Corrupt Organizations (RICO) Act, according to. Prosecutors alleged that YSL was a street gang connected t
|
||||
o a sleuth of crimes. Prosecutors say YSL - the acronym for the artist\'s label, Young Stoner Life Records - also stands for Young Slime Life, an Atlanta-bas
|
||||
ed street gang affiliated with the national Bloods gang. Trontavious Stephens, a co-founder of YSL, took the witness stand for hours Wednesday and Thursday.
|
||||
Stephens testified that he had a criminal record, but he said he did not commit crimes with ... Young Thug\'s arrest sent shockwaves through the rap communit
|
||||
y. On May 9, 2022, the YSL rapper was apprehended along with 27 other alleged gang members as part of a sprawling 56-count indictment ...', name='duckduckgo_
|
||||
search', id='153464f0-aee1-4097-aae6-191ffa372aa6', tool_call_id='cddeca43-88e2-4279-90aa-4c0b58b8b3c4'), ToolMessage(content='The discography of American ra
|
||||
pper Young Thug consists of three studio albums, two compilation albums, twelve self-released mixtapes, seven commercial mixtapes, three extended plays, and
|
||||
sixty-nine singles (including 71 as a featured artist).. In 2015, Thug released his debut mixtape, Barter 6, which reached number 22 on the Billboard 200. Hi
|
||||
s 2016 mixtape, I\'m Up matched the same position. Atlanta singer (and Thugger\'s rumored girlfriend) also released her own "From a Woman" Friday. Even behin
|
||||
d bars, Young Thug remains prolific, with the rapper dropping his new single "From a Man ... Mariah The Scientist\'s new song "From A Woman" isn\'t musically
|
||||
connected to "From A Man," though Mariah\'s cover art works as a flipside to Thug\'s. She sings about finding someone ... Future, Young Thug & More Join Mus
|
||||
tard On New Album \'Faith Of A Mustard Seed\': Stream. News | Jul 28, 2024, 11:00 AM PDT. HipHopDX brings you all the newest Young Thug albums, songs, and ..
|
||||
. Find top songs and albums by Young Thug including Lifestyle (feat. Young Thug & Rich Homie Quan), It\'s Up and more. ... JEFFREY to the country-ish sides o
|
||||
f Beautiful Thugger Girls—has cracked the mainstream, laying the groundwork for a new crop of fellow eccentrics like Lil Uzi Vert and Playboi Carti. In other
|
||||
words, Thug hasn\'t adjusted ...', name='duckduckgo_search', id='79667433-d46e-4536-a04b-842223008bdc', tool_call_id='2315fb35-11ec-4a6b-be86-405a1efef411')
|
||||
, HumanMessage(content='Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again.', id='2f1
|
||||
b3773-8837-4f14-82b5-0a8bbccf3948'), HumanMessage(content='I don’t know.', id='a08edb17-a41a-458f-903e-0f7e15908e3f')]
|
||||
@@ -1,11 +1,11 @@
|
||||
# used by microservice
|
||||
docarray[full]
|
||||
|
||||
#used by tools
|
||||
duckduckgo-search
|
||||
fastapi
|
||||
huggingface_hub
|
||||
langchain
|
||||
|
||||
#used by tools
|
||||
langchain-google-community
|
||||
langchain-huggingface
|
||||
langchain-openai
|
||||
langchain_community
|
||||
|
||||
@@ -33,15 +33,5 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
|
||||
from .strategy.ragagent import RAGAgent
|
||||
|
||||
return RAGAgent(args, with_memory, custom_prompt=custom_prompt)
|
||||
elif strategy == "sql_agent_llama":
|
||||
print("Initializing SQL Agent Llama")
|
||||
from .strategy.sqlagent import SQLAgentLlama
|
||||
|
||||
return SQLAgentLlama(args, with_memory, custom_prompt=custom_prompt)
|
||||
elif strategy == "sql_agent":
|
||||
print("Initializing SQL Agent")
|
||||
from .strategy.sqlagent import SQLAgent
|
||||
|
||||
return SQLAgent(args, with_memory, custom_prompt=custom_prompt)
|
||||
else:
|
||||
raise ValueError(f"Agent strategy: {strategy} not supported!")
|
||||
|
||||
@@ -72,16 +72,3 @@ if os.environ.get("with_store") is not None:
|
||||
|
||||
if os.environ.get("timeout") is not None:
|
||||
env_config += ["--timeout", os.environ["timeout"]]
|
||||
|
||||
# for sql agent
|
||||
if os.environ.get("db_path") is not None:
|
||||
env_config += ["--db_path", os.environ["db_path"]]
|
||||
|
||||
if os.environ.get("db_name") is not None:
|
||||
env_config += ["--db_name", os.environ["db_name"]]
|
||||
|
||||
if os.environ.get("use_hints") is not None:
|
||||
env_config += ["--use_hints", os.environ["use_hints"]]
|
||||
|
||||
if os.environ.get("hints_file") is not None:
|
||||
env_config += ["--hints_file", os.environ["hints_file"]]
|
||||
|
||||
@@ -36,37 +36,5 @@ class BaseAgent:
|
||||
def execute(self, state: dict):
|
||||
pass
|
||||
|
||||
def prepare_initial_state(self, query):
|
||||
def non_streaming_run(self, query, config):
|
||||
raise NotImplementedError
|
||||
|
||||
async def stream_generator(self, query, config):
|
||||
initial_state = self.prepare_initial_state(query)
|
||||
try:
|
||||
async for event in self.app.astream(initial_state, config=config):
|
||||
for node_name, node_state in event.items():
|
||||
yield f"--- CALL {node_name} ---\n"
|
||||
for k, v in node_state.items():
|
||||
if v is not None:
|
||||
yield f"{k}: {v}\n"
|
||||
|
||||
yield f"data: {repr(event)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
yield str(e)
|
||||
|
||||
async def non_streaming_run(self, query, config):
|
||||
initial_state = self.prepare_initial_state(query)
|
||||
print("@@@ Initial State: ", initial_state)
|
||||
try:
|
||||
async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
|
||||
message = s["messages"][-1]
|
||||
if isinstance(message, tuple):
|
||||
print(message)
|
||||
else:
|
||||
message.pretty_print()
|
||||
|
||||
last_message = s["messages"][-1]
|
||||
print("******Response: ", last_message.content)
|
||||
return last_message.content
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
# SQL Agents
|
||||
|
||||
We currently have two types of SQL agents:
|
||||
|
||||
1. `sql_agent_llama`: for using with open-source LLMs, especially `meta-llama/Llama-3.1-70B-Instruct` model.
|
||||
2. `sql_agent`: for using with OpenAI models, we developed and validated with GPT-4o-mini.
|
||||
|
||||
## Overview of sql_agent_llama
|
||||
|
||||
The architecture of `sql_agent_llama` is shown in the figure below.
|
||||
The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem.
|
||||
|
||||

|
||||
|
||||
### Database schema:
|
||||
|
||||
We use langchain's [SQLDatabase](https://python.langchain.com/api_reference/community/utilities/langchain_community.utilities.sql_database.SQLDatabase.html#langchain_community.utilities.sql_database.SQLDatabase) API to get table names and schemas from the SQL database. User just need to specify `db_path` and `db_name`. The table schemas are incorporated into the prompts for the agent.
|
||||
|
||||
### Hints module:
|
||||
|
||||
If you want to use the hints module, you need to prepare a csv file that has 3 columns: `table_name`, `column_name`, `description`, and make this file available to the agent microservice. The `description` should include useful information (for example, domain knowledge) about a certain column in a table in the database. The hints module will pick up to five relevant columns together with their descriptions based on the user question using similarity search. The hints module will then pass these column descriptions to the agent node.
|
||||
|
||||
### Output parser:
|
||||
|
||||
Due to the current limitations of open source LLMs and serving frameworks (tgi and vllm) in generating tool call objects, we developed and optimized a custom output parser, together with our specially designed prompt templates. The output parser has 3 functions:
|
||||
|
||||
1. Decide if a valid final answer presents in the raw agent output. This is needed because: a) we found sometimes agent would make guess or hallucinate data, so it is critical to double check, b) sometimes LLM does not strictly follow instructions on output format so simple string parsing can fail. We use one additional LLM call to perform this function.
|
||||
2. Pick out tool calls from raw agent output. And check if the agent has made same tool calls before. If yes, remove the repeated tool calls.
|
||||
3. Parse and review SQL query, and fix SQL query if there are errors. This proved to improve SQL agent performance since the initial query may contain errors and having a "second pair of eyes" can often spot the errors while the agent node itself may not be able to identify the errors in subsequent execution steps.
|
||||
|
||||
## Overview of sql_agent
|
||||
|
||||
The architecture of `sql_agent` is shown in the figure below.
|
||||
The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem. The basic idea is the same as `sql_agent_llama`. However, since OpenAI APIs produce well-structured tool call objects, we don't need a special output parser. Instead, we only keep the query fixer.
|
||||
|
||||

|
||||
|
||||
## Limitations
|
||||
|
||||
1. Agent connects to local SQLite databases with uri.
|
||||
2. Agent is only allowed to issue "SELECT" commands to databases, i.e., agent can only query databases but cannot update databases.
|
||||
3. We currently does not support "streaming" agent outputs on the fly for `sql_agent_llama`.
|
||||
|
||||
Please submit issues if you want new features to be added. We also welcome community contributions!
|
||||
@@ -1,5 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .planner import SQLAgentLlama
|
||||
from .planner import SQLAgent
|
||||
@@ -1,56 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def read_hints(hints_file):
|
||||
"""
|
||||
hints_file: csv with columns: table_name, column_name, description
|
||||
"""
|
||||
hints_df = pd.read_csv(hints_file)
|
||||
cols_descriptions = []
|
||||
values_descriptions = []
|
||||
for _, row in hints_df.iterrows():
|
||||
table_name = row["table_name"]
|
||||
col_name = row["column_name"]
|
||||
description = row["description"]
|
||||
if not pd.isnull(description):
|
||||
cols_descriptions.append(f"{table_name}.{col_name}: {description}")
|
||||
values_descriptions.append(f"{col_name}: {description}")
|
||||
return cols_descriptions, values_descriptions
|
||||
|
||||
|
||||
def sort_list(list1, list2):
|
||||
import numpy as np
|
||||
|
||||
# Use numpy's argsort function to get the indices that would sort the second list
|
||||
idx = np.argsort(list2) # ascending order
|
||||
return np.array(list1)[idx].tolist()[::-1], np.array(list2)[idx].tolist()[::-1] # descending order
|
||||
|
||||
|
||||
def get_topk_cols(topk, cols_descriptions, similarities):
|
||||
sorted_cols, similarities = sort_list(cols_descriptions, similarities)
|
||||
top_k_cols = sorted_cols[:topk]
|
||||
output = []
|
||||
for col, sim in zip(top_k_cols, similarities[:topk]):
|
||||
# print(f"{col}: {sim}")
|
||||
if sim > 0.5:
|
||||
output.append(col)
|
||||
return output
|
||||
|
||||
|
||||
def pick_hints(query, model, column_embeddings, complete_descriptions, topk=5):
|
||||
# use similarity to get the topk columns
|
||||
query_embedding = model.encode(query, convert_to_tensor=True)
|
||||
similarities = model.similarity(query_embedding, column_embeddings).flatten()
|
||||
|
||||
topk_cols_descriptions = get_topk_cols(topk, complete_descriptions, similarities)
|
||||
|
||||
hint = ""
|
||||
for col in topk_cols_descriptions:
|
||||
hint += col + "\n"
|
||||
return hint
|
||||
@@ -1,322 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated, Sequence, TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langgraph.graph import END, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.managed import IsLastStep
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from ...utils import setup_chat_model, tool_renderer
|
||||
from ..base_agent import BaseAgent
|
||||
from .hint import pick_hints, read_hints
|
||||
from .prompt import AGENT_NODE_TEMPLATE, AGENT_SYSM, QUERYFIXER_PROMPT
|
||||
from .sql_tools import get_sql_query_tool, get_table_schema
|
||||
from .utils import (
|
||||
LlamaOutputParserAndQueryFixer,
|
||||
assemble_history,
|
||||
convert_json_to_tool_call,
|
||||
remove_repeated_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
"""The state of the agent."""
|
||||
|
||||
messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
is_last_step: IsLastStep
|
||||
hint: str
|
||||
|
||||
|
||||
class AgentNodeLlama:
|
||||
def __init__(self, args, tools):
|
||||
self.llm = setup_chat_model(args)
|
||||
self.args = args
|
||||
# two types of tools:
|
||||
# sql_db_query - always available, no need to specify
|
||||
# other tools - user defined
|
||||
# here, self.tools is a list of user defined tools
|
||||
self.tools = tool_renderer(tools)
|
||||
print("@@@@ Tools: ", self.tools)
|
||||
|
||||
self.chain = self.llm
|
||||
|
||||
self.output_parser = LlamaOutputParserAndQueryFixer(chat_model=self.llm)
|
||||
|
||||
if args.use_hints:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.cols_descriptions, self.values_descriptions = read_hints(args.hints_file)
|
||||
self.embed_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
|
||||
self.column_embeddings = self.embed_model.encode(self.values_descriptions)
|
||||
|
||||
def __call__(self, state):
|
||||
print("----------Call Agent Node----------")
|
||||
question = state["messages"][0].content
|
||||
table_schema, num_tables = get_table_schema(self.args.db_path)
|
||||
if self.args.use_hints:
|
||||
if not state["hint"]:
|
||||
hints = pick_hints(question, self.embed_model, self.column_embeddings, self.cols_descriptions)
|
||||
else:
|
||||
hints = state["hint"]
|
||||
print("@@@ Hints: ", hints)
|
||||
else:
|
||||
hints = ""
|
||||
|
||||
history = assemble_history(state["messages"])
|
||||
print("@@@ History: ", history)
|
||||
|
||||
prompt = AGENT_NODE_TEMPLATE.format(
|
||||
domain=self.args.db_name,
|
||||
tools=self.tools,
|
||||
num_tables=num_tables,
|
||||
tables_schema=table_schema,
|
||||
question=question,
|
||||
hints=hints,
|
||||
history=history,
|
||||
)
|
||||
|
||||
output = self.chain.invoke(prompt)
|
||||
output = self.output_parser.parse(
|
||||
output.content, history, table_schema, hints, question, state["messages"]
|
||||
) # text: str, history: str, db_schema: str, hint: str
|
||||
print("@@@@@ Agent output:\n", output)
|
||||
|
||||
# convert output to tool calls
|
||||
tool_calls = []
|
||||
for res in output:
|
||||
if "tool" in res:
|
||||
tool_call = convert_json_to_tool_call(res)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
# check if same tool calls have been made before
|
||||
# if yes, then remove the repeated tool calls
|
||||
if tool_calls:
|
||||
new_tool_calls = remove_repeated_tool_calls(tool_calls, state["messages"])
|
||||
print("@@@@ New Tool Calls:\n", new_tool_calls)
|
||||
else:
|
||||
new_tool_calls = []
|
||||
|
||||
if new_tool_calls:
|
||||
ai_message = AIMessage(content="", tool_calls=new_tool_calls)
|
||||
elif tool_calls:
|
||||
ai_message = AIMessage(content="Repeated previous steps.", tool_calls=tool_calls)
|
||||
elif "answer" in output[0]:
|
||||
ai_message = AIMessage(content=str(output[0]["answer"]))
|
||||
else:
|
||||
ai_message = AIMessage(content=str(output))
|
||||
|
||||
return {"messages": [ai_message], "hint": hints}
|
||||
|
||||
|
||||
class SQLAgentLlama(BaseAgent):
|
||||
# need new args:
|
||||
# # db_name and db_path
|
||||
# # use_hints, hints_file
|
||||
def __init__(self, args, with_memory=False, **kwargs):
|
||||
super().__init__(args, local_vars=globals(), **kwargs)
|
||||
# note: here tools only include user defined tools
|
||||
# we need to add the sql query tool as well
|
||||
print("@@@@ user defined tools: ", self.tools_descriptions)
|
||||
agent = AgentNodeLlama(args, self.tools_descriptions)
|
||||
sql_tool = get_sql_query_tool(args.db_path)
|
||||
print("@@@@ SQL Tool: ", sql_tool)
|
||||
tools = self.tools_descriptions + [sql_tool]
|
||||
print("@@@@ ALL Tools: ", tools)
|
||||
tool_node = ToolNode(tools)
|
||||
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Define the nodes we will cycle between
|
||||
workflow.add_node("agent", agent)
|
||||
workflow.add_node("tools", tool_node)
|
||||
|
||||
workflow.set_entry_point("agent")
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"agent",
|
||||
self.decide_next_step,
|
||||
{
|
||||
# If `tools`, then we call the tool node.
|
||||
"tools": "tools",
|
||||
"agent": "agent",
|
||||
"end": END,
|
||||
},
|
||||
)
|
||||
|
||||
# We now add a normal edge from `tools` to `agent`.
|
||||
# This means that after `tools` is called, `agent` node is called next.
|
||||
workflow.add_edge("tools", "agent")
|
||||
|
||||
self.app = workflow.compile()
|
||||
|
||||
def decide_next_step(self, state: AgentState):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls and last_message.content == "Repeated previous steps.":
|
||||
print("@@@@ Repeated tool calls from previous steps, go back to agent")
|
||||
return "agent"
|
||||
elif last_message.tool_calls and last_message.content != "Repeated previous steps.":
|
||||
print("@@@@ New Tool calls, go to tools")
|
||||
return "tools"
|
||||
else:
|
||||
return "end"
|
||||
|
||||
def prepare_initial_state(self, query):
|
||||
return {"messages": [HumanMessage(content=query)], "is_last_step": IsLastStep(False), "hint": ""}
|
||||
|
||||
|
||||
################################################
|
||||
# Below is SQL agent using OpenAI models
|
||||
################################################
|
||||
class AgentNode:
|
||||
def __init__(self, args, llm, tools):
|
||||
self.llm = llm.bind_tools(tools)
|
||||
self.args = args
|
||||
if args.use_hints:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.cols_descriptions, self.values_descriptions = read_hints(args.hints_file)
|
||||
self.embed_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
|
||||
self.column_embeddings = self.embed_model.encode(self.values_descriptions)
|
||||
|
||||
def __call__(self, state):
|
||||
print("----------Call Agent Node----------")
|
||||
question = state["messages"][0].content
|
||||
table_schema, num_tables = get_table_schema(self.args.db_path)
|
||||
if self.args.use_hints:
|
||||
if not state["hint"]:
|
||||
hints = pick_hints(question, self.embed_model, self.column_embeddings, self.cols_descriptions)
|
||||
else:
|
||||
hints = state["hint"]
|
||||
else:
|
||||
hints = ""
|
||||
|
||||
sysm = AGENT_SYSM.format(num_tables=num_tables, tables_schema=table_schema, question=question, hints=hints)
|
||||
_system_message = SystemMessage(content=sysm)
|
||||
state_modifier_runnable = RunnableLambda(
|
||||
lambda state: [_system_message] + state["messages"],
|
||||
name="StateModifier",
|
||||
)
|
||||
|
||||
chain = state_modifier_runnable | self.llm
|
||||
response = chain.invoke(state)
|
||||
|
||||
return {"messages": [response], "hint": hints}
|
||||
|
||||
|
||||
class QueryFixerNode:
|
||||
def __init__(self, args, llm):
|
||||
prompt = PromptTemplate(
|
||||
template=QUERYFIXER_PROMPT,
|
||||
input_variables=["DATABASE_SCHEMA", "QUESTION", "HINT", "QUERY", "RESULT"],
|
||||
)
|
||||
self.chain = prompt | llm
|
||||
self.args = args
|
||||
|
||||
def get_sql_query_and_result(self, state):
|
||||
messages = state["messages"]
|
||||
assert isinstance(messages[-1], ToolMessage), "The last message should be a tool message"
|
||||
result = messages[-1].content
|
||||
id = messages[-1].tool_call_id
|
||||
query = ""
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
if msg.tool_calls[0]["id"] == id:
|
||||
query = msg.tool_calls[0]["args"]["query"]
|
||||
break
|
||||
print("@@@@ Executed SQL Query: ", query)
|
||||
print("@@@@ Execution Result: ", result)
|
||||
return query, result
|
||||
|
||||
def __call__(self, state):
|
||||
print("----------Call Query Fixer Node----------")
|
||||
table_schema, _ = get_table_schema(self.args.db_path)
|
||||
question = state["messages"][0].content
|
||||
hint = state["hint"]
|
||||
query, result = self.get_sql_query_and_result(state)
|
||||
response = self.chain.invoke(
|
||||
{
|
||||
"DATABASE_SCHEMA": table_schema,
|
||||
"QUESTION": question,
|
||||
"HINT": hint,
|
||||
"QUERY": query,
|
||||
"RESULT": result,
|
||||
}
|
||||
)
|
||||
# print("@@@@@ Query fixer output:\n", response.content)
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
class SQLAgent(BaseAgent):
|
||||
def __init__(self, args, with_memory=False, **kwargs):
|
||||
super().__init__(args, local_vars=globals(), **kwargs)
|
||||
|
||||
sql_tool = get_sql_query_tool(args.db_path)
|
||||
tools = self.tools_descriptions + [sql_tool]
|
||||
print("@@@@ ALL Tools: ", tools)
|
||||
|
||||
tool_node = ToolNode(tools)
|
||||
agent = AgentNode(args, self.llm, tools)
|
||||
query_fixer = QueryFixerNode(args, self.llm)
|
||||
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Define the nodes we will cycle between
|
||||
workflow.add_node("agent", agent)
|
||||
workflow.add_node("query_fixer", query_fixer)
|
||||
workflow.add_node("tools", tool_node)
|
||||
|
||||
workflow.set_entry_point("agent")
|
||||
|
||||
# We now add a conditional edge
|
||||
workflow.add_conditional_edges(
|
||||
"agent",
|
||||
self.should_continue,
|
||||
{
|
||||
# If `tools`, then we call the tool node.
|
||||
"continue": "tools",
|
||||
"end": END,
|
||||
},
|
||||
)
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"tools",
|
||||
self.should_go_to_query_fixer,
|
||||
{"true": "query_fixer", "false": "agent"},
|
||||
)
|
||||
workflow.add_edge("query_fixer", "agent")
|
||||
|
||||
self.app = workflow.compile()
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def should_continue(self, state: AgentState):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
# If there is no function call, then we finish
|
||||
if not last_message.tool_calls:
|
||||
return "end"
|
||||
# Otherwise if there is, we continue
|
||||
else:
|
||||
return "continue"
|
||||
|
||||
def should_go_to_query_fixer(self, state: AgentState):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
assert isinstance(last_message, ToolMessage), "The last message should be a tool message"
|
||||
print("@@@@ Called Tool: ", last_message.name)
|
||||
if last_message.name == "sql_db_query":
|
||||
print("@@@@ Going to Query Fixer")
|
||||
return "true"
|
||||
else:
|
||||
print("@@@@ Going back to Agent")
|
||||
return "false"
|
||||
|
||||
def prepare_initial_state(self, query):
|
||||
return {"messages": [HumanMessage(content=query)], "is_last_step": IsLastStep(False), "hint": ""}
|
||||
@@ -1,225 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
AGENT_NODE_TEMPLATE = """\
|
||||
You are an SQL expert tasked with answering questions about {domain}.
|
||||
In addition to the database, you have the following tools to gather information:
|
||||
{tools}
|
||||
|
||||
You can access a database that has {num_tables} tables. The schema of the tables is as follows. Read the schema carefully.
|
||||
**Table Schema:**
|
||||
{tables_schema}
|
||||
|
||||
**Hints:**
|
||||
{hints}
|
||||
|
||||
When querying the database, remember the following:
|
||||
1. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to no more than 20 results.
|
||||
2. Only query columns that are relevant to the question. Remember to also fetch the ranking or filtering columns to check if they contain nulls.
|
||||
3. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
|
||||
**Output format:**
|
||||
1. Write down your thinking process.
|
||||
2. When querying the database, write your SQL query in the following format:
|
||||
```sql
|
||||
SELECT column1, column2, ...
|
||||
```
|
||||
3. When making tool calls, you must use the following format. Make ONLY one tool call at a time.
|
||||
TOOL CALL: {{"tool": "tool1", "args": {{"arg1": "value1", "arg2": "value2", ...}}}}
|
||||
|
||||
4. After you have arrived at the answer with data and reasoning, write your final answer after "FINAL ANSWER".
|
||||
|
||||
You have done the following steps so far:
|
||||
**Your previous steps:**
|
||||
{history}
|
||||
|
||||
**IMPORTANT:**
|
||||
* Review your previous steps carefully and utilize them to answer the question. Do not repeat your previous steps.
|
||||
* The database may not have all the information needed to answer the question. Use the additional tools provided if necessary.
|
||||
* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box.
|
||||
|
||||
Now take a deep breath and think step by step to answeer the following question.
|
||||
Question:
|
||||
{question}
|
||||
"""
|
||||
|
||||
|
||||
ANSWER_PARSER_PROMPT = """\
|
||||
Review the output from an SQL agent and determine if a correct answer has been provided and grounded on real data.
|
||||
|
||||
Say "yes" when all the following conditions are met:
|
||||
1. The answer is complete and does not require additional steps to be taken.
|
||||
2. The answer does not have placeholders that need to be filled in.
|
||||
3. The agent has acquired data from database and its execution history is Not empty.
|
||||
4. If agent made mistakes in its execution history, the agent has corrected them.
|
||||
5. If agent has tried to get data several times but cannot get all the data needed, the agent has come up with an answer based on available data and reasonable assumptions.
|
||||
|
||||
If the conditions above are not met, say "no".
|
||||
|
||||
Here is the output from the SQL agent:
|
||||
{output}
|
||||
======================
|
||||
Here is the agent execution history:
|
||||
{history}
|
||||
======================
|
||||
|
||||
Has a final answer been provided based on real data? Analyze the agent output and make your judgement "yes" or "no".
|
||||
"""
|
||||
|
||||
|
||||
SQL_QUERY_FIXER_PROMPT = """\
|
||||
You are an SQL database expert tasked with reviewing a SQL query written by an agent.
|
||||
**Procedure:**
|
||||
1. Review Database Schema:
|
||||
- Examine the table creation statements to understand the database structure.
|
||||
2. Review the Hint provided.
|
||||
- Use the provided hints to understand the domain knowledge relevant to the query.
|
||||
3. Check against the following common errors:
|
||||
- Failure to exclude null values, ranking or filtering columns have nulls, syntax errors, incorrect table references, incorrect column references, logical mistakes.
|
||||
4. Check if aggregation should be used:
|
||||
- Read the user question, and determine if user is asking for specific instances or aggregated info. If aggregation is needed, check if the original SQL query has used appropriate functions like COUNT and SUM.
|
||||
5. Correct the Query only when Necessary:
|
||||
- If issues were identified, modify the SQL query to address the identified issues, ensuring it correctly fetches the requested data according to the database schema and query requirements.
|
||||
|
||||
======= Your task =======
|
||||
**************************
|
||||
Table creation statements
|
||||
{DATABASE_SCHEMA}
|
||||
**************************
|
||||
Hint:
|
||||
{HINT}
|
||||
**************************
|
||||
The SQL query to review:
|
||||
{QUERY}
|
||||
**************************
|
||||
User question:
|
||||
{QUESTION}
|
||||
**************************
|
||||
|
||||
Now analyze the SQL query step by step. Present your reasonings.
|
||||
|
||||
If you identified issues in the original query, write down the corrected SQL query in the format below:
|
||||
```sql
|
||||
SELECT column1, column2, ...
|
||||
```
|
||||
|
||||
If the original SQL query is correct, just say the query is correct.
|
||||
|
||||
Note: Some user questions can only be answered partially with the database. This is OK. The agent may use other tools in subsequent steps to get additional info. In some cases, the agent may have got additional info with other tools and have incorporated those in its query. Your goal is to review the SQL query and fix it when necessary.
|
||||
Only use the tables provided in the database schema in your corrected query. Do not join tables that are not present in the schema. Do not create any new tables.
|
||||
If you cannot do better than the original query, just say the query is correct.
|
||||
"""
|
||||
|
||||
SQL_QUERY_FIXER_PROMPT_with_result = """\
|
||||
You are an SQL database expert tasked with reviewing a SQL query.
|
||||
**Procedure:**
|
||||
1. Review Database Schema:
|
||||
- Examine the table creation statements to understand the database structure.
|
||||
2. Review the Hint provided.
|
||||
- Use the provided hints to understand the domain knowledge relevant to the query.
|
||||
3. Analyze Query Requirements:
|
||||
- User Question: Consider what information the query is supposed to retrieve. Decide if aggregation like COUNT or SUM is needed.
|
||||
- Executed SQL Query: Review the SQL query that was previously executed.
|
||||
- Execution Result: Analyze the outcome of the executed query. Think carefully if the result makes sense.
|
||||
4. Check against the following common errors:
|
||||
- Failure to exclude null values, ranking or filtering columns have nulls, syntax errors, incorrect table references, incorrect column references, logical mistakes.
|
||||
5. Correct the Query only when Necessary:
|
||||
- If issues were identified, modify the SQL query to address the identified issues, ensuring it correctly fetches the requested data according to the database schema and query requirements.
|
||||
|
||||
======= Your task =======
|
||||
**************************
|
||||
Table creation statements
|
||||
{DATABASE_SCHEMA}
|
||||
**************************
|
||||
Hint:
|
||||
{HINT}
|
||||
**************************
|
||||
User Question:
|
||||
{QUESTION}
|
||||
**************************
|
||||
The SQL query executed was:
|
||||
{QUERY}
|
||||
**************************
|
||||
The execution result:
|
||||
{RESULT}
|
||||
**************************
|
||||
|
||||
Now analyze the SQL query step by step. Present your reasonings.
|
||||
|
||||
If you identified issues in the original query, write down the corrected SQL query in the format below:
|
||||
```sql
|
||||
SELECT column1, column2, ...
|
||||
```
|
||||
|
||||
If the original SQL query is correct, just say the query is correct.
|
||||
|
||||
Note: Some user questions can only be answered partially with the database. This is OK. The agent may use other tools in subsequent steps to get additional info. In some cases, the agent may have got additional info with other tools and have incorporated those in its query. Your goal is to review the SQL query and fix it when necessary.
|
||||
Only use the tables provided in the database schema in your corrected query. Do not join tables that are not present in the schema. Do not create any new tables.
|
||||
If you cannot do better than the original query, just say the query is correct.
|
||||
"""
|
||||
|
||||
|
||||
##########################################
|
||||
## Prompt templates for SQL agent using OpenAI models
|
||||
##########################################
|
||||
AGENT_SYSM = """\
|
||||
You are an SQL expert tasked with answering questions about schools in California.
|
||||
You can access a database that has {num_tables} tables. The schema of the tables is as follows. Read the schema carefully.
|
||||
{tables_schema}
|
||||
****************
|
||||
Question: {question}
|
||||
|
||||
Hints:
|
||||
{hints}
|
||||
****************
|
||||
|
||||
When querying the database, remember the following:
|
||||
1. You MUST double check your SQL query before executing it. Reflect on the steps you have taken and fix errors if there are any. If you get an error while executing a query, rewrite the query and try again.
|
||||
2. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to no more than 20 results.
|
||||
3. Only query columns that are relevant to the question.
|
||||
4. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
|
||||
IMPORTANT:
|
||||
* Divide the question into sub-questions and conquer sub-questions one by one.
|
||||
* You may need to combine information from multiple tables to answer the question.
|
||||
* If database does not have all the information needed to answer the question, use the web search tool or your own knowledge.
|
||||
* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box. You hard work will be rewarded.
|
||||
|
||||
Now take a deep breath and think step by step to solve the problem.
|
||||
"""
|
||||
|
||||
QUERYFIXER_PROMPT = """\
|
||||
You are an SQL database expert tasked with reviewing a SQL query.
|
||||
**Procedure:**
|
||||
1. Review Database Schema:
|
||||
- Examine the table creation statements to understand the database structure.
|
||||
2. Review the Hint provided.
|
||||
- Use the provided hints to understand the domain knowledge relevant to the query.
|
||||
3. Analyze Query Requirements:
|
||||
- Original Question: Consider what information the query is supposed to retrieve.
|
||||
- Executed SQL Query: Review the SQL query that was previously executed.
|
||||
- Execution Result: Analyze the outcome of the executed query. Think carefully if the result makes sense. If the result does not make sense, identify the issues with the executed SQL query (e.g., null values, syntax
|
||||
errors, incorrect table references, incorrect column references, logical mistakes).
|
||||
4. Correct the Query if Necessary:
|
||||
- If issues were identified, modify the SQL query to address the identified issues, ensuring it correctly fetches the requested data
|
||||
according to the database schema and query requirements.
|
||||
5. If the query is correct, provide the same query as the final answer.
|
||||
|
||||
======= Your task =======
|
||||
**************************
|
||||
Table creation statements
|
||||
{DATABASE_SCHEMA}
|
||||
**************************
|
||||
Hint:
|
||||
{HINT}
|
||||
**************************
|
||||
The original question is:
|
||||
Question:
|
||||
{QUESTION}
|
||||
The SQL query executed was:
|
||||
{QUERY}
|
||||
The execution result:
|
||||
{RESULT}
|
||||
**************************
|
||||
Based on the question, table schema, hint and the previous query, analyze the result. Fix the query if needed and provide your reasoning. If the query is correct, provide the same query as the final answer.
|
||||
"""
|
||||
@@ -1,32 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
|
||||
|
||||
def connect_to_db(db_path):
|
||||
uri = "sqlite:///{path}".format(path=db_path)
|
||||
db = SQLDatabase.from_uri(uri)
|
||||
return db
|
||||
|
||||
|
||||
def get_table_schema(db_path):
|
||||
db = connect_to_db(db_path)
|
||||
table_names = ", ".join(db.get_usable_table_names())
|
||||
num_tables = len(table_names.split(","))
|
||||
schema = db.get_table_info_no_throw([t.strip() for t in table_names.split(",")])
|
||||
return schema, num_tables
|
||||
|
||||
|
||||
def get_sql_query_tool(db_path):
|
||||
db = connect_to_db(db_path)
|
||||
query_sql_database_tool_description = (
|
||||
"Input to this tool is a detailed and correct SQL query, output is a "
|
||||
"result from the database. If the query is not correct, an error message "
|
||||
"will be returned. If an error is returned, rewrite the query, check the "
|
||||
"query, and try again. "
|
||||
)
|
||||
db_query_tool = QuerySQLDataBaseTool(db=db, name="sql_db_query", description=query_sql_database_tool_description)
|
||||
print("SQL Query Tool Created: ", db_query_tool)
|
||||
return db_query_tool
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
|
||||
from .prompt import ANSWER_PARSER_PROMPT, SQL_QUERY_FIXER_PROMPT, SQL_QUERY_FIXER_PROMPT_with_result
|
||||
|
||||
|
||||
def parse_answer_with_llm(text, history, chat_model):
|
||||
if "FINAL ANSWER:" in text.upper():
|
||||
if history == "":
|
||||
history = "The agent execution history is empty."
|
||||
|
||||
prompt = ANSWER_PARSER_PROMPT.format(output=text, history=history)
|
||||
response = chat_model.invoke(prompt).content
|
||||
print("@@@ Answer parser response: ", response)
|
||||
|
||||
temp = response[:5]
|
||||
if "yes" in temp.lower():
|
||||
return text.split("FINAL ANSWER:")[-1]
|
||||
else:
|
||||
temp = response.split("\n")[0]
|
||||
if "yes" in temp.lower():
|
||||
return text.split("FINAL ANSWER:")[-1]
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_tool_calls_other_than_sql(text):
|
||||
"""Get the tool calls other than sql_db_query."""
|
||||
tool_calls = []
|
||||
text = text.replace("assistant", "")
|
||||
json_lines = text.split("\n")
|
||||
# only get the unique lines
|
||||
json_lines = list(set(json_lines))
|
||||
for line in json_lines:
|
||||
if "TOOL CALL:" in line:
|
||||
if "sql_db_query" not in line:
|
||||
line = line.replace("TOOL CALL:", "")
|
||||
if "assistant" in line:
|
||||
line = line.replace("assistant", "")
|
||||
if "\\" in line:
|
||||
line = line.replace("\\", "")
|
||||
try:
|
||||
parsed_line = json.loads(line)
|
||||
if isinstance(parsed_line, dict):
|
||||
if "tool" in parsed_line:
|
||||
tool_calls.append(parsed_line)
|
||||
|
||||
except:
|
||||
pass
|
||||
return tool_calls
|
||||
|
||||
|
||||
def get_all_sql_queries(text):
|
||||
queries = []
|
||||
if "```sql" in text:
|
||||
temp = text.split("```sql")
|
||||
for t in temp:
|
||||
if "```" in t:
|
||||
query = t.split("```")[0]
|
||||
if "SELECT" in query.upper() and "TOOL CALL" not in query.upper():
|
||||
queries.append(query)
|
||||
|
||||
return queries
|
||||
|
||||
|
||||
def get_the_last_sql_query(text):
|
||||
queries = get_all_sql_queries(text)
|
||||
if queries:
|
||||
return queries[-1]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_query_if_executed_and_result(query, messages):
|
||||
# get previous sql_db_query tool calls
|
||||
previous_tool_calls = []
|
||||
for m in messages:
|
||||
if isinstance(m, AIMessage) and m.tool_calls:
|
||||
for tc in m.tool_calls:
|
||||
if tc["name"] == "sql_db_query":
|
||||
previous_tool_calls.append(tc)
|
||||
for tc in previous_tool_calls:
|
||||
if query == tc["args"]["query"]:
|
||||
return get_tool_output(messages, tc["id"])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_and_fix_sql_query_v2(text, chat_model, db_schema, hint, question, messages):
|
||||
chosen_query = get_the_last_sql_query(text)
|
||||
if chosen_query:
|
||||
# check if the query has been executed before
|
||||
# if yes, pass execution result to the fixer
|
||||
# if not, pass only the query to the fixer
|
||||
result = check_query_if_executed_and_result(chosen_query, messages)
|
||||
if result:
|
||||
prompt = SQL_QUERY_FIXER_PROMPT_with_result.format(
|
||||
DATABASE_SCHEMA=db_schema, HINT=hint, QUERY=chosen_query, QUESTION=question, RESULT=result
|
||||
)
|
||||
else:
|
||||
prompt = SQL_QUERY_FIXER_PROMPT.format(
|
||||
DATABASE_SCHEMA=db_schema, HINT=hint, QUERY=chosen_query, QUESTION=question
|
||||
)
|
||||
|
||||
response = chat_model.invoke(prompt).content
|
||||
print("@@@ SQL query fixer response: ", response)
|
||||
if "query is correct" in response.lower():
|
||||
return chosen_query
|
||||
else:
|
||||
# parse the fixed query
|
||||
fixed_query = get_the_last_sql_query(response)
|
||||
return fixed_query
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class LlamaOutputParserAndQueryFixer:
|
||||
def __init__(self, chat_model):
|
||||
self.chat_model = chat_model
|
||||
|
||||
def parse(self, text: str, history: str, db_schema: str, hint: str, question: str, messages: list):
|
||||
print("@@@ Raw output from llm:\n", text)
|
||||
answer = parse_answer_with_llm(text, history, self.chat_model)
|
||||
if answer:
|
||||
print("Final answer exists.")
|
||||
return answer
|
||||
else:
|
||||
tool_calls = get_tool_calls_other_than_sql(text)
|
||||
sql_query = parse_and_fix_sql_query_v2(text, self.chat_model, db_schema, hint, question, messages)
|
||||
if sql_query:
|
||||
sql_tool_call = [{"tool": "sql_db_query", "args": {"query": sql_query}}]
|
||||
tool_calls.extend(sql_tool_call)
|
||||
if tool_calls:
|
||||
return tool_calls
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def convert_json_to_tool_call(json_str):
|
||||
tool_name = json_str["tool"]
|
||||
tool_args = json_str["args"]
|
||||
tcid = str(uuid.uuid4())
|
||||
tool_call = ToolCall(name=tool_name, args=tool_args, id=tcid)
|
||||
return tool_call
|
||||
|
||||
|
||||
def get_tool_output(messages, id):
|
||||
tool_output = ""
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, ToolMessage):
|
||||
if msg.tool_call_id == id:
|
||||
tool_output = msg.content
|
||||
tool_output = tool_output[:1000] # limit to 1000 characters
|
||||
break
|
||||
return tool_output
|
||||
|
||||
|
||||
def assemble_history(messages):
|
||||
"""
|
||||
messages: AI, TOOL, AI, TOOL, etc.
|
||||
"""
|
||||
query_history = ""
|
||||
breaker = "-" * 10
|
||||
n = 1
|
||||
for m in messages[1:]: # exclude the first message
|
||||
if isinstance(m, AIMessage):
|
||||
# if there is tool call
|
||||
if hasattr(m, "tool_calls") and len(m.tool_calls) > 0 and m.content != "Repeated previous steps.":
|
||||
for tool_call in m.tool_calls:
|
||||
tool = tool_call["name"]
|
||||
tc_args = tool_call["args"]
|
||||
id = tool_call["id"]
|
||||
tool_output = get_tool_output(messages, id)
|
||||
if tool == "sql_db_query":
|
||||
sql_query = tc_args["query"]
|
||||
query_history += (
|
||||
f"Step {n}. Executed SQL query: {sql_query}\nQuery Result: {tool_output}\n{breaker}\n"
|
||||
)
|
||||
else:
|
||||
query_history += (
|
||||
f"Step {n}. Called tool: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n"
|
||||
)
|
||||
n += 1
|
||||
elif m.content == "Repeated previous steps.": # repeated steps
|
||||
query_history += f"Step {n}. Repeated tool calls from previous steps.\n{breaker}\n"
|
||||
n += 1
|
||||
else:
|
||||
# did not make tool calls
|
||||
query_history += f"Assistant Output: {m.content}\n"
|
||||
|
||||
return query_history
|
||||
|
||||
|
||||
def remove_repeated_tool_calls(tool_calls, messages):
|
||||
"""Remove repeated tool calls in the messages.
|
||||
|
||||
tool_calls: list of tool calls: ToolCall(name=tool_name, args=tool_args, id=tcid)
|
||||
messages: list of messages: AIMessage, ToolMessage, HumanMessage
|
||||
"""
|
||||
# first get all the previous tool calls in messages
|
||||
previous_tool_calls = []
|
||||
for m in messages:
|
||||
if isinstance(m, AIMessage) and m.tool_calls and m.content != "Repeated previous steps.":
|
||||
for tc in m.tool_calls:
|
||||
previous_tool_calls.append({"tool": tc["name"], "args": tc["args"]})
|
||||
|
||||
unique_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
if {"tool": tc["name"], "args": tc["args"]} not in previous_tool_calls:
|
||||
unique_tool_calls.append(tc)
|
||||
|
||||
return unique_tool_calls
|
||||
@@ -139,14 +139,8 @@ def get_args():
|
||||
parser.add_argument("--with_store", type=bool, default=False)
|
||||
parser.add_argument("--timeout", type=int, default=60)
|
||||
|
||||
# for sql agent
|
||||
parser.add_argument("--db_path", type=str, help="database path")
|
||||
parser.add_argument("--db_name", type=str, help="database name")
|
||||
parser.add_argument("--use_hints", type=str, default="false", help="If this agent uses hints")
|
||||
parser.add_argument("--hints_file", type=str, help="path to the hints file")
|
||||
|
||||
sys_args, unknown_args = parser.parse_known_args()
|
||||
print("env_config: ", env_config)
|
||||
# print("env_config: ", env_config)
|
||||
if env_config != []:
|
||||
env_args, env_unknown_args = parser.parse_known_args(env_config)
|
||||
unknown_args += env_unknown_args
|
||||
@@ -157,12 +151,5 @@ def get_args():
|
||||
sys_args.streaming = True
|
||||
else:
|
||||
sys_args.streaming = False
|
||||
|
||||
if sys_args.use_hints == "true":
|
||||
print("SQL agent will use hints")
|
||||
sys_args.use_hints = True
|
||||
else:
|
||||
sys_args.use_hints = False
|
||||
|
||||
print("==========sys_args==========:\n", sys_args)
|
||||
return sys_args, unknown_args
|
||||
|
||||
Reference in New Issue
Block a user