Files
GenAIExamples/ChatQnA/benchmark/accuracy/eval_crud.py
lkk 088ab98f31 update examples accuracy (#941)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-10-14 13:20:50 +08:00

211 lines
8.2 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os
from evals.evaluation.rag_eval import Evaluator
from evals.evaluation.rag_eval.template import CRUDTemplate
from evals.metrics.ragas import RagasMetric
from tqdm import tqdm
class CRUD_Evaluator(Evaluator):
def get_ground_truth_text(self, data: dict):
if self.task == "summarization":
ground_truth_text = data["summary"]
elif self.task == "question_answering":
ground_truth_text = data["answers"]
elif self.task == "continuation":
ground_truth_text = data["continuing"]
elif self.task == "hallucinated_modified":
ground_truth_text = data["hallucinatedMod"]
else:
raise NotImplementedError(
f"Unknown task {self.task}, only support "
"summarization, question_answering, continuation and hallucinated_modified."
)
return ground_truth_text
def get_query(self, data: dict):
if self.task == "summarization":
query = data["text"]
elif self.task == "question_answering":
query = data["questions"]
elif self.task == "continuation":
query = data["beginning"]
elif self.task == "hallucinated_modified":
query = data["newsBeginning"]
else:
raise NotImplementedError(
f"Unknown task {self.task}, only support "
"summarization, question_answering, continuation and hallucinated_modified."
)
return query
def get_document(self, data: dict):
if self.task == "summarization":
document = data["text"]
elif self.task == "question_answering":
document = data["news1"]
elif self.task == "continuation":
document = data["beginning"]
elif self.task == "hallucinated_modified":
document = data["newsBeginning"]
else:
raise NotImplementedError(
f"Unknown task {self.task}, only support "
"summarization, question_answering, continuation and hallucinated_modified."
)
return document
def get_template(self):
if self.task == "summarization":
template = CRUDTemplate.get_summarization_template()
elif self.task == "question_answering":
template = CRUDTemplate.get_question_answering_template()
elif self.task == "continuation":
template = CRUDTemplate.get_continuation_template()
else:
raise NotImplementedError(
f"Unknown task {self.task}, only support "
"summarization, question_answering, continuation and hallucinated_modified."
)
return template
def post_process(self, result):
return result.split("<response>")[-1].split("</response>")[0].strip()
def get_ragas_metrics(self, results, arguments):
from langchain_huggingface import HuggingFaceEndpointEmbeddings
embeddings = HuggingFaceEndpointEmbeddings(model=arguments.tei_embedding_endpoint)
metric = RagasMetric(
threshold=0.5,
model=arguments.llm_endpoint,
embeddings=embeddings,
metrics=["faithfulness", "answer_relevancy"],
)
all_answer_relevancy = 0
all_faithfulness = 0
ragas_inputs = {
"question": [],
"answer": [],
"ground_truth": [],
"contexts": [],
}
valid_results = self.remove_invalid(results["results"])
for data in tqdm(valid_results):
data = data["original_data"]
query = self.get_query(data)
generated_text = data["generated_text"]
ground_truth = data["ground_truth_text"]
retrieved_documents = data["retrieved_documents"]
ragas_inputs["question"].append(query)
ragas_inputs["answer"].append(generated_text)
ragas_inputs["ground_truth"].append(ground_truth)
ragas_inputs["contexts"].append(retrieved_documents[:3])
ragas_metrics = metric.measure(ragas_inputs)
return ragas_metrics
def args_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--service_url", type=str, default="http://localhost:8888/v1/chatqna", help="Service URL address."
)
parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save evaluation results.")
parser.add_argument(
"--temperature", type=float, default=0.1, help="Controls the randomness of the model's text generation"
)
parser.add_argument(
"--max_new_tokens", type=int, default=1280, help="Maximum number of new tokens to be generated by the model"
)
parser.add_argument(
"--chunk_size", type=int, default=256, help="the maximum number of characters that a chunk can contain"
)
parser.add_argument(
"--chunk_overlap",
type=int,
default=100,
help="the number of characters that should overlap between two adjacent chunks",
)
parser.add_argument("--dataset_path", default="../data/split_merged.json", help="Path to the dataset")
parser.add_argument("--docs_path", default="../data/80000_docs", help="Path to the retrieval documents")
# Retriever related options
parser.add_argument("--tasks", default=["question_answering"], nargs="+", help="Task to perform")
parser.add_argument("--ingest_docs", action="store_true", help="Whether to ingest documents to vector database")
parser.add_argument(
"--database_endpoint", type=str, default="http://localhost:6007/v1/dataprep", help="Service URL address."
)
parser.add_argument(
"--embedding_endpoint", type=str, default="http://localhost:6000/v1/embeddings", help="Service URL address."
)
parser.add_argument(
"--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="Service URL address."
)
parser.add_argument(
"--tei_embedding_endpoint",
type=str,
default="http://localhost:8090",
help="Service URL address of tei embedding.",
)
parser.add_argument("--ragas_metrics", action="store_true", help="Whether to compute ragas metrics.")
parser.add_argument("--llm_endpoint", type=str, default=None, help="Service URL address.")
parser.add_argument(
"--show_progress_bar", action="store", default=True, type=bool, help="Whether to show a progress bar"
)
parser.add_argument("--contain_original_data", action="store_true", help="Whether to contain original data")
args = parser.parse_args()
return args
def main():
args = args_parser()
if os.path.isfile(args.dataset_path):
with open(args.dataset_path) as f:
all_datasets = json.load(f)
else:
raise FileNotFoundError(f"Evaluation dataset file {args.dataset_path} not exist.")
os.makedirs(args.output_dir, exist_ok=True)
for task in args.tasks:
if task == "question_answering":
dataset = all_datasets["questanswer_1doc"]
elif task == "summarization":
dataset = all_datasets["event_summary"]
else:
raise NotImplementedError(
f"Unknown task {task}, only support "
"summarization, question_answering, continuation and hallucinated_modified."
)
output_save_path = os.path.join(args.output_dir, f"{task}.json")
evaluator = CRUD_Evaluator(dataset=dataset, output_path=output_save_path, task=task)
if args.ingest_docs:
CRUD_Evaluator.ingest_docs(args.docs_path, args.database_endpoint, args.chunk_size, args.chunk_overlap)
results = evaluator.evaluate(
args, show_progress_bar=args.show_progress_bar, contain_original_data=args.contain_original_data
)
print(results["overall"])
if args.ragas_metrics:
ragas_metrics = evaluator.get_ragas_metrics(results, args)
print(ragas_metrics)
print(f"Evaluation results of task {task} saved to {output_save_path}.")
if __name__ == "__main__":
main()