#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import os

import requests
from evals.evaluation.rag_eval import Evaluator
from evals.metrics.ragas import RagasMetric
from evals.metrics.retrieval import RetrievalBaseMetric
from tqdm import tqdm


class MultiHop_Evaluator(Evaluator):
    def get_ground_truth_text(self, data: dict):
        return data["answer"]

    def get_query(self, data: dict):
        return data["query"]

    def get_template(self):
        return None

    def get_reranked_documents(self, query, docs, arguments):
        data = {
            "initial_query": query,
            "retrieved_docs": [{"text": doc} for doc in docs],
            "top_n": 10,
        }
        headers = {"Content-Type": "application/json"}

        response = requests.post(arguments.reranking_endpoint, data=json.dumps(data), headers=headers)
        if response.ok:
            reranked_documents = response.json()["documents"]
            return reranked_documents
        else:
            print(f"Request for reranking failed due to {response.text}.")
            return []
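
    # Example of the reranking request/response exchanged above (request shape taken
    # from the payload built in get_reranked_documents; the response schema is assumed
    # to match the deployed reranking microservice):
    #   request:  {"initial_query": "...", "retrieved_docs": [{"text": "..."}, ...], "top_n": 10}
    #   response: {"documents": [...]}  # reranked documents, used as the new retrieval context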

    def get_retrieved_documents(self, query, arguments):
        # Embed the query with the TEI embedding service first, then query the
        # retriever with the resulting embedding vector.
        data = {"inputs": query}
        headers = {"Content-Type": "application/json"}
        response = requests.post(arguments.tei_embedding_endpoint + "/embed", data=json.dumps(data), headers=headers)
        if response.ok:
            embedding = response.json()[0]
        else:
            print(f"Request for embedding failed due to {response.text}.")
            return []
        data = {
            "text": query,
            "embedding": embedding,
            "search_type": arguments.search_type,
            "k": arguments.retrival_k,
            "fetch_k": arguments.fetch_k,
            "lambda_mult": arguments.lambda_mult,
        }
        response = requests.post(arguments.retrieval_endpoint, data=json.dumps(data), headers=headers)
        if response.ok:
            retrieved_documents = response.json()["retrieved_docs"]
            return [doc["text"] for doc in retrieved_documents]
        else:
            print(f"Request for retrieval failed due to {response.text}.")
            return []
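
    # Sketch of the two-step retrieval flow above (shapes taken from the payloads built
    # in this method; the numeric values are the argparse defaults):
    #   embed request:      {"inputs": "<query>"}  ->  [[0.12, ...], ...]  (first row is used)
    #   retrieval request:  {"text": "<query>", "embedding": [...], "search_type": "similarity",
    #                        "k": 10, "fetch_k": 20, "lambda_mult": 0.5}
    #   retrieval response: {"retrieved_docs": [{"text": "..."}, ...]}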

    def get_retrieval_metrics(self, all_queries, arguments):
        print("start to retrieve...")
        metric = RetrievalBaseMetric()
        hits_at_10 = 0
        hits_at_4 = 0
        map_at_10 = 0
        mrr_at_10 = 0
        total = 0
        for data in tqdm(all_queries):
            if data["question_type"] == "null_query":
                continue
            query = data["query"]
            retrieved_documents = self.get_retrieved_documents(query, arguments)
            if arguments.rerank:
                retrieved_documents = self.get_reranked_documents(query, retrieved_documents, arguments)
            golden_context = [each["fact"] for each in data["evidence_list"]]
            test_case = {
                "input": query,
                "golden_context": golden_context,
                "retrieval_context": retrieved_documents,
            }
            results = metric.measure(test_case)
            hits_at_10 += results["Hits@10"]
            hits_at_4 += results["Hits@4"]
            map_at_10 += results["MAP@10"]
            mrr_at_10 += results["MRR@10"]
            total += 1

        # Calculate average metrics over all evaluated queries
        hits_at_10 = hits_at_10 / total
        hits_at_4 = hits_at_4 / total
        map_at_10 = map_at_10 / total
        mrr_at_10 = mrr_at_10 / total

        return {
            "Hits@10": hits_at_10,
            "Hits@4": hits_at_4,
            "MAP@10": map_at_10,
            "MRR@10": mrr_at_10,
        }

    def evaluate(self, all_queries, arguments):
        results = []
        accuracy = 0
        index = 0
        for data in tqdm(all_queries):
            if data["question_type"] == "null_query":
                continue

            generated_text = self.send_request(data, arguments)
            data["generated_text"] = generated_text

            # Substring-match accuracy, same method as the paper: https://github.com/yixuantt/MultiHop-RAG/issues/8
            if data["answer"] in generated_text:
                accuracy += 1
            result = {"id": index, **self.scoring(data)}
            results.append(result)
            index += 1

        valid_results = self.remove_invalid(results)

        try:
            overall = self.compute_overall(valid_results) if len(valid_results) > 0 else {}
        except Exception as e:
            print(repr(e))
            overall = dict()

        overall.update({"accuracy": accuracy / len(results)})
        return overall

    def get_ragas_metrics(self, all_queries, arguments):
        from langchain_huggingface import HuggingFaceEndpointEmbeddings

        embeddings = HuggingFaceEndpointEmbeddings(model=arguments.tei_embedding_endpoint)

        metric = RagasMetric(threshold=0.5, model=arguments.llm_endpoint, embeddings=embeddings)
        all_answer_relevancy = 0
        all_faithfulness = 0
        ragas_inputs = {
            "question": [],
            "answer": [],
            "ground_truth": [],
            "contexts": [],
        }

        for data in tqdm(all_queries):
            if data["question_type"] == "null_query":
                continue
            retrieved_documents = self.get_retrieved_documents(data["query"], arguments)
            generated_text = self.send_request(data, arguments)
            data["generated_text"] = generated_text

            ragas_inputs["question"].append(data["query"])
            ragas_inputs["answer"].append(generated_text)
            ragas_inputs["ground_truth"].append(data["answer"])
            ragas_inputs["contexts"].append(retrieved_documents[:3])

            if len(ragas_inputs["question"]) >= arguments.limits:
                break

        ragas_metrics = metric.measure(ragas_inputs)
        return ragas_metrics
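
    # Example of the ragas_inputs dict assembled above: four parallel lists of equal
    # length, one entry per evaluated query.
    #   {"question": ["q1", ...], "answer": ["generated1", ...],
    #    "ground_truth": ["gold1", ...], "contexts": [["doc1", "doc2", "doc3"], ...]}
    # RagasMetric.measure is assumed to return aggregated RAGAS scores (e.g. answer
    # relevancy and faithfulness); the exact keys depend on the installed ragas version.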


def args_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--service_url", type=str, default="http://localhost:8888/v1/chatqna", help="Service URL address."
    )
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save evaluation results.")
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Controls the randomness of the model's text generation."
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=1280, help="Maximum number of new tokens to be generated by the model."
    )
    parser.add_argument(
        "--chunk_size", type=int, default=256, help="The maximum number of characters that a chunk can contain."
    )
    parser.add_argument(
        "--chunk_overlap",
        type=int,
        default=100,
        help="The number of characters that should overlap between two adjacent chunks.",
    )
    parser.add_argument("--search_type", type=str, default="similarity", help="Type of search to use for retrieval.")
    parser.add_argument("--retrival_k", type=int, default=10, help="Number of documents to return.")
    parser.add_argument(
        "--fetch_k", type=int, default=20, help="Number of documents to fetch to pass to the MMR algorithm."
    )
    parser.add_argument(
        "--lambda_mult",
        type=float,
        default=0.5,
        help="Number between 0 and 1 that determines the degree of diversity among the results, with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.",
    )
    parser.add_argument("--dataset_path", default=None, help="Path to the dataset.")
    parser.add_argument("--docs_path", default=None, help="Path to the retrieval documents.")

    # Retriever related options
    parser.add_argument("--ingest_docs", action="store_true", help="Whether to ingest documents into the vector database.")
    parser.add_argument("--retrieval_metrics", action="store_true", help="Whether to compute retrieval metrics.")
    parser.add_argument("--ragas_metrics", action="store_true", help="Whether to compute ragas metrics.")
    parser.add_argument("--limits", type=int, default=100, help="Number of examples to be evaluated by llm-as-judge.")
    parser.add_argument(
        "--database_endpoint", type=str, default="http://localhost:6007/v1/dataprep", help="Service URL address."
    )
    parser.add_argument(
        "--embedding_endpoint", type=str, default="http://localhost:6000/v1/embeddings", help="Service URL address."
    )
    parser.add_argument(
        "--tei_embedding_endpoint",
        type=str,
        default="http://localhost:8090",
        help="Service URL address of the TEI embedding service.",
    )
    parser.add_argument(
        "--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="Service URL address."
    )
    parser.add_argument("--rerank", action="store_true", help="Whether to use the rerank microservice.")
    parser.add_argument(
        "--reranking_endpoint", type=str, default="http://localhost:8000/v1/reranking", help="Service URL address."
    )
    parser.add_argument("--llm_endpoint", type=str, default=None, help="Service URL address.")
    parser.add_argument(
        "--show_progress_bar", action="store", default=True, type=bool, help="Whether to show a progress bar."
    )
    parser.add_argument("--contain_original_data", action="store_true", help="Whether to include the original data in the results.")

    args = parser.parse_args()
    return args


def main():
    args = args_parser()

    evaluator = MultiHop_Evaluator()

    with open(args.docs_path, "r") as file:
        doc_data = json.load(file)

    documents = []
    for doc in doc_data:
        metadata = {"title": doc["title"], "published_at": doc["published_at"], "source": doc["source"]}
        documents.append(doc["body"])

    # Save the document bodies to a temporary corpus file for ingestion
    tmp_corpus_file = "tmp_corpus.txt"
    with open(tmp_corpus_file, "w") as f:
        for doc in documents:
            f.write(doc + "\n")

    if args.ingest_docs:
        evaluator.ingest_docs(tmp_corpus_file, args.database_endpoint, args.chunk_size, args.chunk_overlap)

    with open(args.dataset_path, "r") as file:
        all_queries = json.load(file)

    # Retrieval quality (Hits@k, MAP@10, MRR@10)
    if args.retrieval_metrics:
        retrieval_metrics = evaluator.get_retrieval_metrics(all_queries, args)
        print(retrieval_metrics)

    # End-to-end RAG quality (RAGAS)
    if args.ragas_metrics:
        ragas_metrics = evaluator.get_ragas_metrics(all_queries, args)
        print(ragas_metrics)


if __name__ == "__main__":
    main()
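

# Example invocation (illustrative sketch; the script name and data file names below
# are placeholders, and the default microservice endpoints from args_parser() are
# assumed to be running):
#
#   python eval_multihop.py \
#       --docs_path corpus.json \
#       --dataset_path MultiHopRAG.json \
#       --ingest_docs \
#       --retrieval_metrics \
#       --rerank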