Add Finance Agent Example (#1752)
Signed-off-by: minmin-intel <minmin.hou@intel.com> Signed-off-by: Rita Brugarolas <rita.brugarolas.brufau@intel.com> Signed-off-by: rbrugaro <rita.brugarolas.brufau@intel.com> Co-authored-by: rbrugaro <rita.brugarolas.brufau@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: lkk <33276950+lkk12014402@users.noreply.github.com> Co-authored-by: lkk12014402 <kaokao.lv@intel.com>
This commit is contained in:
359
FinanceAgent/tools/utils.py
Normal file
359
FinanceAgent/tools/utils.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# Copyright (C) 2025 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_core.documents import Document
|
||||
from langchain_huggingface import HuggingFaceEndpointEmbeddings
|
||||
from langchain_redis import RedisConfig, RedisVectorStore
|
||||
from openai import OpenAI
|
||||
|
||||
try:
|
||||
from tools.redis_kv import RedisKVStore
|
||||
except ImportError:
|
||||
from redis_kv import RedisKVStore
|
||||
|
||||
# Embedding model
# EMBED_MODEL: local HuggingFace model id, used only when no TEI endpoint is configured.
EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-base-en-v1.5")
# TEI_EMBEDDING_ENDPOINT: remote Text-Embeddings-Inference service URL; empty string
# means "embed locally" (see get_embedder()).
TEI_EMBEDDING_ENDPOINT = os.getenv("TEI_EMBEDDING_ENDPOINT", "")

# Redis URL
# Two separate Redis instances: one for the vector index, one for the key-value doc store.
REDIS_URL_VECTOR = os.getenv("REDIS_URL_VECTOR", "redis://localhost:6379/")
REDIS_URL_KV = os.getenv("REDIS_URL_KV", "redis://localhost:6380/")

# LLM config
# NOTE(review): env var names are lowercase ("model", "llm_endpoint_url") unlike the
# uppercase convention above — presumably matching the agent launcher's env; confirm.
LLM_MODEL = os.getenv("model", "meta-llama/Llama-3.3-70B-Instruct")
LLM_ENDPOINT = os.getenv("llm_endpoint_url", "http://localhost:8086")
# NOTE(review): import-time print side effect; kept for parity with service logs.
print(f"LLM endpoint: {LLM_ENDPOINT}")
# Generation defaults used by generate_answer().
MAX_TOKENS = 1024
TEMPERATURE = 0.2

# Prompt asking the LLM to map a free-form company name onto one of the names
# stored in the knowledge base; the answer is expected inside {braces}.
COMPANY_NAME_PROMPT = """\
Here is the list of company names in the knowledge base:
{company_list}

This is the company of interest: {company}

Determine if the company of interest is the same as any of the companies in the knowledge base.
If yes, map the company of interest to the company name in the knowledge base. Output the company name in {{}}. Example: {{3M}}.
If none of the companies in the knowledge base match the company of interest, output "NONE".
"""

# RAG answer prompt; not referenced in this file — presumably used by callers
# that import it (TODO confirm).
ANSWER_PROMPT = """\
You are a financial analyst. Read the documents below and answer the question.
Documents:
{documents}

Question: {query}
Now take a deep breath and think step by step to answer the question. Wrap your final answer in {{}}. Example: {{The company has a revenue of $100 million.}}
"""
||||
def format_company_name(company):
    """Normalize a company name and resolve it to its knowledge-base entry.

    Raises:
        ValueError: when the company cannot be matched against the KB.
    """
    known_companies = get_company_list()
    print(f"company_list {known_companies}")

    resolved = get_company_name_in_kb(company.upper(), known_companies)
    lookup_failed = "Cannot find" in resolved or "Database is empty" in resolved
    if lookup_failed:
        raise ValueError(f"Company not found in knowledge base: {resolved}")

    print(f"Company: {resolved}")
    return resolved
|
||||
|
||||
|
||||
def get_embedder():
    """Create the embedding model used for dense retrieval.

    Returns:
        A HuggingFaceEndpointEmbeddings client when TEI_EMBEDDING_ENDPOINT is
        configured, otherwise a local HuggingFaceBgeEmbeddings model.

    Raises:
        ValueError: if a TEI endpoint is configured but no HuggingFace API
            token is present in the environment.
    """
    if TEI_EMBEDDING_ENDPOINT:
        # create embeddings using TEI endpoint service.
        # Fix: validate with an explicit raise instead of `assert`, which is
        # silently stripped when Python runs with -O.
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")
        if not hf_token:
            raise ValueError("HuggingFace API token is required for TEI embedding endpoint.")
        embedder = HuggingFaceEndpointEmbeddings(model=TEI_EMBEDDING_ENDPOINT)
    else:
        # create embeddings using local embedding model
        embedder = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
    return embedder
|
||||
|
||||
|
||||
def generate_answer(prompt):
    """Send `prompt` to the vLLM OpenAI-compatible endpoint and return the text reply."""
    client = OpenAI(base_url=f"{LLM_ENDPOINT}/v1", api_key="token-abc123")

    completion = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
    )

    response = completion.choices[0].message.content
    print(f"LLM Response: {response}")
    return response
|
||||
|
||||
|
||||
def parse_response(response):
    """Extract the text the LLM wrapped in braces.

    Returns the text after the first "{" — truncated at the next "{" or "}",
    whichever the split hits first — or "" when no brace is present.
    """
    if "{" not in response:
        return ""
    after_open = response.partition("{")[2]
    segment = after_open.split("{", 1)[0]
    return segment.partition("}")[0]
|
||||
|
||||
|
||||
def get_company_list():
    """Return the list of company names stored in the KV store ([] when absent)."""
    kvstore = RedisKVStore(redis_uri=REDIS_URL_KV)
    record = kvstore.get("company", "company_list")
    if not record:
        return []
    return record["company"]
|
||||
|
||||
|
||||
def get_company_name_in_kb(company, company_list):
    """Map `company` onto a canonical name from `company_list`.

    Tries an exact (uppercased) match first; otherwise asks the LLM to resolve
    close variants. On failure returns an explanatory message string rather
    than raising.
    """
    if not company_list:
        return "Database is empty."

    company = company.upper()
    if company in company_list:
        return company

    # No exact hit — let the LLM match the name against the known companies.
    prompt = COMPANY_NAME_PROMPT.format(company_list=company_list, company=company)
    response = generate_answer(prompt)
    if "NONE" in response.upper():
        return f"Cannot find {company} in knowledge base."

    parsed = parse_response(response)
    return parsed if parsed else "Failed to parse LLM response."
|
||||
|
||||
|
||||
def get_docs_matching_metadata(metadata, collection_name):
    """Return every doc in `collection_name` whose metadata[key] equals value.

    metadata: ("company_year", "3M_2023")
    docs: list of documents
    """
    key, value = metadata
    kvstore = RedisKVStore(redis_uri=REDIS_URL_KV)
    collection = kvstore.get_all(collection_name)  # collection is a dict

    matching_docs = []
    for doc in collection.values():
        if doc["metadata"][key] != value:
            continue
        print(f"Found doc with matching metadata {metadata}")
        print(doc["metadata"]["doc_title"])
        matching_docs.append(doc)

    print(f"Number of docs found with search_metadata {metadata}: {len(matching_docs)}")
    return matching_docs
|
||||
|
||||
|
||||
def convert_docs(docs):
    """Convert raw KV-store doc dicts into two parallel lists of Documents.

    Returns (content_docs, summary_docs). Summary docs carry the full content
    inside their metadata so it can be recovered after retrieval.
    """
    content_docs = []
    summary_docs = []
    for doc in docs:
        body = doc["content"]
        doc_id = doc["metadata"]["doc_id"]

        # Content document: the chunk text itself.
        content_docs.append(
            Document(
                id=doc_id,
                page_content=body,
                metadata={"type": "content", **doc["metadata"]},
            )
        )
        # Summary document: summary as page text, original content stashed in metadata.
        summary_docs.append(
            Document(
                id=doc_id,
                page_content=doc["summary"],
                metadata={"type": "summary", "content": body, **doc["metadata"]},
            )
        )
    return content_docs, summary_docs
|
||||
|
||||
|
||||
def bm25_search(query, metadata, company, doc_type="chunks", k=10):
    """BM25 retrieval over both doc content and doc summaries, filtered by metadata."""
    collection_name = f"{doc_type}_{company}"
    print(f"Collection name: {collection_name}")

    docs = get_docs_matching_metadata(metadata, collection_name)
    if not docs:
        return []

    docs_text, docs_summary = convert_docs(docs)

    # BM25 search over content
    content_hits = BM25Retriever.from_documents(docs_text, k=k).invoke(query)
    print(f"BM25: Found {len(content_hits)} docs over content with search metadata: {metadata}")

    # BM25 search over summary/title
    summary_hits = BM25Retriever.from_documents(docs_summary, k=k).invoke(query)
    print(f"BM25: Found {len(summary_hits)} docs over summary with search metadata: {metadata}")

    return content_hits + summary_hits
|
||||
|
||||
|
||||
def bm25_search_broad(query, company, year, quarter, k=10, doc_type="chunks"):
    """BM25 search with progressively relaxed metadata filters.

    Runs one broad search (company filter, year/quarter folded into the query)
    plus a narrow search whose filter relaxes from company_year_quarter to
    company_year to company when nothing matches. Returns the combined list.
    """
    # search with company filter, but query is query_company_quarter
    broad_query = f"{query} {year} {quarter}"
    broad_docs = bm25_search(broad_query, ("company", f"{company}"), company, k=k, doc_type=doc_type)

    # search with metadata filters
    metadata = ("company_year_quarter", f"{company}_{year}_{quarter}")
    print(f"BM25: Searching for docs with metadata: {metadata}")
    docs = bm25_search(query, metadata, company, k=k, doc_type=doc_type)

    if not docs:
        print("BM25: No docs found with company, year and quarter filter, only search with company and year filter")
        docs = bm25_search(query, ("company_year", f"{company}_{year}"), company, k=k, doc_type=doc_type)

    if not docs:
        print("BM25: No docs found with company and year filter, only search with company filter")
        docs = bm25_search(query, ("company", f"{company}"), company, k=k, doc_type=doc_type)

    # Both operands are lists, so an empty combined result is already [].
    return docs + broad_docs
|
||||
|
||||
|
||||
def set_filter(metadata_filter):
    """Build a redisvl Text equality filter from a (key, value) tuple."""
    from redisvl.query.filter import Text

    key, value = metadata_filter
    return Text(key) == value
|
||||
|
||||
|
||||
def similarity_search(vector_store, k, query, company, year, quarter=None):
    """Dense retrieval with progressively relaxed metadata filters.

    Performs one company-filtered search with year/quarter folded into the
    query text, plus a second search whose filter relaxes from
    company_year_quarter to company_year to company when nothing matches.
    Returns the combined document list.
    """
    augmented_query = f"{query} {year} {quarter}"
    company_docs = vector_store.similarity_search(
        augmented_query, k=k, filter=set_filter(("company", company))
    )
    print(f"Similarity search: Found {len(company_docs)} docs with company filter and query: {augmented_query}")

    docs = vector_store.similarity_search(
        query, k=k, filter=set_filter(("company_year_quarter", f"{company}_{year}_{quarter}"))
    )

    if not docs:  # relax to company + year
        print("No relevant document found with company, year and quarter filter, only search with company and year")
        docs = vector_store.similarity_search(
            query, k=k, filter=set_filter(("company_year", f"{company}_{year}"))
        )

    if not docs:  # relax to company only
        print("No relevant document found with company_year filter, only search with company.....")
        docs = vector_store.similarity_search(query, k=k, filter=set_filter(("company", company)))

    print(f"Similarity search: Found {len(docs)} docs with filter and query: {query}")

    # Both operands are lists, so an empty combined result is already [].
    return docs + company_docs
|
||||
|
||||
|
||||
def get_index_name(doc_type: str, metadata: dict):
    """Build the per-company collection/index name for a document type.

    Args:
        doc_type: one of "chunks", "tables", "titles", "full_doc".
        metadata: must contain a "company" key.

    Returns:
        The index name f"{doc_type}_{company}".

    Raises:
        KeyError: if metadata lacks "company" (checked first, as before).
        ValueError: for an unrecognized doc_type.
    """
    # Fix: the original four if/elif branches all produced f"{doc_type}_{company}";
    # collapsed into one validated format. "company" is still read before the
    # doc_type check to preserve the original KeyError-before-ValueError order.
    company = metadata["company"]
    if doc_type not in ("chunks", "tables", "titles", "full_doc"):
        raise ValueError("doc_type should be either chunks, tables, titles, or full_doc.")
    return f"{doc_type}_{company}"
|
||||
|
||||
|
||||
def get_content(doc):
    """Return the full text for a retrieved Document.

    BM25-retrieved docs (metadata "type" of "summary" or "content") already
    carry their text; dense-retriever docs require a KV-store lookup by doc_id.
    """
    doc_kind = doc.metadata.get("type")

    if doc_kind == "summary":
        print("BM25 retrieved doc...")
        return doc.metadata["content"]

    if doc_kind == "content":
        print("BM25 retrieved doc...")
        return doc.page_content

    # Dense-retriever doc: fetch the full content from the KV store.
    print("Dense retriever doc...")
    kvstore = RedisKVStore(redis_uri=REDIS_URL_KV)
    collection_name = get_index_name(doc.metadata["doc_type"], doc.metadata)
    record = kvstore.get(doc.metadata["doc_id"], collection_name)
    return record["content"]
|
||||
|
||||
|
||||
def get_unique_docs(docs):
    """Deduplicate docs by content and format them into one numbered context string."""
    seen_contents = []
    context = ""
    for doc in docs:
        content = get_content(doc)
        if content in seen_contents:
            continue
        seen_contents.append(content)
        title = doc.metadata["doc_title"]
        # Number equals the count of unique docs so far (1-based).
        context += f"Doc [{len(seen_contents)}] from {title}:\n{content}\n"
    print(f"Number of unique docs found: {len(seen_contents)}")
    return context
|
||||
|
||||
|
||||
def get_vectorstore(index_name):
    """Open a RedisVectorStore on `index_name` with the full chunk metadata schema."""
    text_fields = [
        "company",
        "year",
        "quarter",
        "doc_type",
        "doc_title",
        "doc_id",
        "company_year",
        "company_year_quarter",
    ]
    config = RedisConfig(
        index_name=index_name,
        redis_url=REDIS_URL_VECTOR,
        metadata_schema=[{"name": field, "type": "text"} for field in text_fields],
    )
    return RedisVectorStore(get_embedder(), config=config)
|
||||
|
||||
|
||||
def get_vectorstore_titles(index_name):
    """Open a RedisVectorStore for title docs (chunk schema without doc_id)."""
    text_fields = [
        "company",
        "year",
        "quarter",
        "doc_type",
        "doc_title",
        "company_year",
        "company_year_quarter",
    ]
    config = RedisConfig(
        index_name=index_name,
        redis_url=REDIS_URL_VECTOR,
        metadata_schema=[{"name": field, "type": "text"} for field in text_fields],
    )
    return RedisVectorStore(get_embedder(), config=config)
|
||||
Reference in New Issue
Block a user