GenAIExamples/ChatQnA/benchmark/accuracy/eval_multihop.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
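
"""Accuracy evaluation of the ChatQnA pipeline on a MultiHop-RAG style dataset.

The script can ingest the retrieval corpus into the vector database, score
retrieval quality (Hits@10, Hits@4, MAP@10, MRR@10), and score generation
quality with ragas, using the service endpoints configured in args_parser().
"""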
import argparse
import json
import os

import requests
from evals.evaluation.rag_eval import Evaluator
from evals.metrics.ragas import RagasMetric
from evals.metrics.retrieval import RetrievalBaseMetric
from tqdm import tqdm


class MultiHop_Evaluator(Evaluator):
    def get_ground_truth_text(self, data: dict):
        return data["answer"]

    def get_query(self, data: dict):
        return data["query"]

    def get_template(self):
        return None
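
    # The two helpers below call the embedding, retriever, and (optionally)
    # reranking services directly: the query is embedded via the TEI /embed
    # route, the embedding is sent to the retriever, and the retrieved texts
    # can then be reranked.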

    def get_reranked_documents(self, query, docs, arguments):
        data = {
            "initial_query": query,
            "retrieved_docs": [{"text": doc} for doc in docs],
            "top_n": 10,
        }
        headers = {"Content-Type": "application/json"}
        response = requests.post(arguments.reranking_endpoint, data=json.dumps(data), headers=headers)
        if response.ok:
            reranked_documents = response.json()["documents"]
            return reranked_documents
        else:
            print(f"Request for reranking failed due to {response.text}.")
            return []

    def get_retrieved_documents(self, query, arguments):
        data = {"inputs": query}
        headers = {"Content-Type": "application/json"}
        response = requests.post(arguments.tei_embedding_endpoint + "/embed", data=json.dumps(data), headers=headers)
        if response.ok:
            embedding = response.json()[0]
        else:
            print(f"Request for embedding failed due to {response.text}.")
            return []
        data = {
            "text": query,
            "embedding": embedding,
            "search_type": arguments.search_type,
            "k": arguments.retrival_k,
            "fetch_k": arguments.fetch_k,
            "lambda_mult": arguments.lambda_mult,
        }
        response = requests.post(arguments.retrieval_endpoint, data=json.dumps(data), headers=headers)
        if response.ok:
            retrieved_documents = response.json()["retrieved_docs"]
            return [doc["text"] for doc in retrieved_documents]
        else:
            print(f"Request for retrieval failed due to {response.text}.")
            return []

    def get_retrieval_metrics(self, all_queries, arguments):
        print("Starting retrieval...")
        metric = RetrievalBaseMetric()
        hits_at_10 = 0
        hits_at_4 = 0
        map_at_10 = 0
        mrr_at_10 = 0
        total = 0
        for data in tqdm(all_queries):
            if data["question_type"] == "null_query":
                continue
            query = data["query"]
            retrieved_documents = self.get_retrieved_documents(query, arguments)
            if arguments.rerank:
                retrieved_documents = self.get_reranked_documents(query, retrieved_documents, arguments)
            golden_context = [each["fact"] for each in data["evidence_list"]]
            test_case = {
                "input": query,
                "golden_context": golden_context,
                "retrieval_context": retrieved_documents,
            }
            results = metric.measure(test_case)
            hits_at_10 += results["Hits@10"]
            hits_at_4 += results["Hits@4"]
            map_at_10 += results["MAP@10"]
            mrr_at_10 += results["MRR@10"]
            total += 1

        # Calculate average metrics over all queries
        hits_at_10 = hits_at_10 / total
        hits_at_4 = hits_at_4 / total
        map_at_10 = map_at_10 / total
        mrr_at_10 = mrr_at_10 / total

        return {
            "Hits@10": hits_at_10,
            "Hits@4": hits_at_4,
            "MAP@10": map_at_10,
            "MRR@10": mrr_at_10,
        }
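
    # End-to-end answer accuracy: a generated answer counts as correct when the
    # gold answer string appears verbatim in the generated text, following the
    # protocol discussed in the MultiHop-RAG issue linked below.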

    def evaluate(self, all_queries, arguments):
        results = []
        accuracy = 0
        index = 0
        for data in tqdm(all_queries):
            if data["question_type"] == "null_query":
                continue
            generated_text = self.send_request(data, arguments)
            data["generated_text"] = generated_text
            # Same method as the paper: https://github.com/yixuantt/MultiHop-RAG/issues/8
            if data["answer"] in generated_text:
                accuracy += 1
            result = {"id": index, **self.scoring(data)}
            results.append(result)
            index += 1

        valid_results = self.remove_invalid(results)
        try:
            overall = self.compute_overall(valid_results) if len(valid_results) > 0 else {}
        except Exception as e:
            print(repr(e))
            overall = dict()

        overall.update({"accuracy": accuracy / len(results)})
        return overall
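
    # Ragas metrics (e.g. answer relevancy and faithfulness) are scored with the
    # TEI endpoint as the embedding model and --llm_endpoint as the judge LLM;
    # only the top-3 retrieved documents are passed as contexts, and scoring
    # stops once --limits examples have been collected.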

    def get_ragas_metrics(self, all_queries, arguments):
        from langchain_huggingface import HuggingFaceEndpointEmbeddings

        embeddings = HuggingFaceEndpointEmbeddings(model=arguments.tei_embedding_endpoint)
        metric = RagasMetric(threshold=0.5, model=arguments.llm_endpoint, embeddings=embeddings)
        all_answer_relevancy = 0
        all_faithfulness = 0
        ragas_inputs = {
            "question": [],
            "answer": [],
            "ground_truth": [],
            "contexts": [],
        }
        for data in tqdm(all_queries):
            if data["question_type"] == "null_query":
                continue
            retrieved_documents = self.get_retrieved_documents(data["query"], arguments)
            generated_text = self.send_request(data, arguments)
            data["generated_text"] = generated_text
            ragas_inputs["question"].append(data["query"])
            ragas_inputs["answer"].append(generated_text)
            ragas_inputs["ground_truth"].append(data["answer"])
            ragas_inputs["contexts"].append(retrieved_documents[:3])
            if len(ragas_inputs["question"]) >= arguments.limits:
                break

        ragas_metrics = metric.measure(ragas_inputs)
        return ragas_metrics
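
# Input file formats (inferred from the field accesses in this script; both
# follow the MultiHop-RAG dataset layout and may carry additional keys):
#
# --docs_path: a JSON list of documents such as
#     {"title": "...", "published_at": "...", "source": "...", "body": "..."}
#
# --dataset_path: a JSON list of queries such as
#     {"query": "...", "answer": "...", "question_type": "inference_query",
#      "evidence_list": [{"fact": "..."}, ...]}
#   Records with question_type == "null_query" are skipped.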


def args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--service_url", type=str, default="http://localhost:8888/v1/chatqna", help="ChatQnA service URL."
    )
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save evaluation results.")
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Controls the randomness of the model's text generation."
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=1280, help="Maximum number of new tokens to be generated by the model."
    )
    parser.add_argument(
        "--chunk_size", type=int, default=256, help="Maximum number of characters that a chunk can contain."
    )
    parser.add_argument(
        "--chunk_overlap",
        type=int,
        default=100,
        help="Number of characters that should overlap between two adjacent chunks.",
    )
    parser.add_argument("--search_type", type=str, default="similarity", help="Similarity search type.")
    parser.add_argument("--retrival_k", type=int, default=10, help="Number of documents to return.")
    parser.add_argument(
        "--fetch_k", type=int, default=20, help="Number of documents to fetch to pass to the MMR algorithm."
    )
    parser.add_argument(
        "--lambda_mult",
        type=float,
        default=0.5,
        help="Number between 0 and 1 that determines the degree of diversity among the results, "
        "with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.",
    )
    parser.add_argument("--dataset_path", default=None, help="Path to the query dataset (JSON).")
    parser.add_argument("--docs_path", default=None, help="Path to the retrieval documents (JSON).")

    # Retriever related options
    parser.add_argument("--ingest_docs", action="store_true", help="Whether to ingest documents into the vector database.")
    parser.add_argument("--retrieval_metrics", action="store_true", help="Whether to compute retrieval metrics.")
    parser.add_argument("--ragas_metrics", action="store_true", help="Whether to compute ragas metrics.")
    parser.add_argument("--limits", type=int, default=100, help="Number of examples to be evaluated by llm-as-judge.")
    parser.add_argument(
        "--database_endpoint", type=str, default="http://localhost:6007/v1/dataprep", help="Dataprep service URL."
    )
    parser.add_argument(
        "--embedding_endpoint", type=str, default="http://localhost:6000/v1/embeddings", help="Embedding service URL."
    )
    parser.add_argument(
        "--tei_embedding_endpoint",
        type=str,
        default="http://localhost:8090",
        help="Service URL of the TEI embedding service.",
    )
    parser.add_argument(
        "--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="Retrieval service URL."
    )
    parser.add_argument("--rerank", action="store_true", help="Whether to use the rerank microservice.")
    parser.add_argument(
        "--reranking_endpoint", type=str, default="http://localhost:8000/v1/reranking", help="Reranking service URL."
    )
    parser.add_argument("--llm_endpoint", type=str, default=None, help="LLM endpoint used as the ragas judge model.")
    parser.add_argument(
        "--show_progress_bar", action="store", default=True, type=bool, help="Whether to show a progress bar."
    )
    parser.add_argument("--contain_original_data", action="store_true", help="Whether to include the original data.")
    args = parser.parse_args()
    return args


def main():
    args = args_parser()
    evaluator = MultiHop_Evaluator()

    with open(args.docs_path, "r") as file:
        doc_data = json.load(file)
    documents = []
    for doc in doc_data:
        # Metadata is collected here but not currently attached to the ingested documents.
        metadata = {"title": doc["title"], "published_at": doc["published_at"], "source": doc["source"]}
        documents.append(doc["body"])

    # save docs to a tmp file
    tmp_corpus_file = "tmp_corpus.txt"
    with open(tmp_corpus_file, "w") as f:
        for doc in documents:
            f.write(doc + "\n")

    if args.ingest_docs:
        evaluator.ingest_docs(tmp_corpus_file, args.database_endpoint, args.chunk_size, args.chunk_overlap)

    with open(args.dataset_path, "r") as file:
        all_queries = json.load(file)

    # get retrieval quality
    if args.retrieval_metrics:
        retrieval_metrics = evaluator.get_retrieval_metrics(all_queries, args)
        print(retrieval_metrics)

    # get rag quality
    if args.ragas_metrics:
        ragas_metrics = evaluator.get_ragas_metrics(all_queries, args)
        print(ragas_metrics)


if __name__ == "__main__":
    main()
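
# Example invocation (a sketch: the file names are illustrative and the
# endpoints shown in args_parser() above are the defaults; adjust both to
# match your deployment):
#
#   python eval_multihop.py \
#       --docs_path ./corpus.json \
#       --dataset_path ./MultiHopRAG.json \
#       --ingest_docs \
#       --retrieval_metrics \
#       --rerank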