move evaluation scripts (#842)
Co-authored-by: root <root@idc708073.jf.intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
78
FaqGen/benchmark/accuracy/README.md
Normal file
78
FaqGen/benchmark/accuracy/README.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# FaqGen Evaluation
|
||||
|
||||
## Dataset
|
||||
|
||||
We evaluate performance on QA dataset [Squad_v2](https://huggingface.co/datasets/rajpurkar/squad_v2). Generate FAQs on "context" columns in validation dataset, which contains 1204 unique records.
|
||||
|
||||
First download the dataset and put it at "./data".
|
||||
|
||||
Extract the unique "context" columns, which will be saved to 'data/sqv2_context.json':
|
||||
|
||||
```
|
||||
python get_context.py
|
||||
```
|
||||
|
||||
## Generate FAQs
|
||||
|
||||
### Launch FaqGen microservice
|
||||
|
||||
Please refer to [FaqGen microservice](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/faq-generation/tgi) to set up a microservice endpoint.
|
||||
|
||||
```
|
||||
export FAQ_ENDPOINT="http://${your_ip}:9000/v1/faqgen"
|
||||
```
|
||||
|
||||
### Generate FAQs with microservice
|
||||
|
||||
Use the microservice endpoint to generate FAQs for dataset.
|
||||
|
||||
```
|
||||
python generate_FAQ.py
|
||||
```
|
||||
|
||||
Post-process the output to get the right data, which will be saved to 'data/sqv2_faq.json'.
|
||||
|
||||
```
|
||||
python post_process_FAQ.py
|
||||
```
|
||||
|
||||
## Evaluate with Ragas
|
||||
|
||||
### Launch TGI service
|
||||
|
||||
We use "mistralai/Mixtral-8x7B-Instruct-v0.1" as the LLM referee to evaluate the model. First we need to launch an LLM endpoint on Gaudi.
|
||||
|
||||
```
|
||||
export HUGGING_FACE_HUB_TOKEN="your_huggingface_token"
|
||||
bash launch_tgi.sh
|
||||
```
|
||||
|
||||
Get the endpoint:
|
||||
|
||||
```
|
||||
export LLM_ENDPOINT="http://${ip_address}:8082"
|
||||
```
|
||||
|
||||
Verify the service:
|
||||
|
||||
```bash
|
||||
curl http://${ip_address}:8082/generate \
|
||||
-X POST \
|
||||
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":128}}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
|
||||
Evaluate the performance with the LLM:
|
||||
|
||||
```
|
||||
python evaluate.py
|
||||
```
|
||||
|
||||
### Performance Result
|
||||
|
||||
Here are the test results for your reference:
|
||||
| answer_relevancy | faithfulness | context_utilization | reference_free_rubrics_score |
|
||||
| ---- | ---- |---- |---- |
|
||||
| 0.7191 | 0.9681 | 0.8964 | 4.4125|
|
||||
44
FaqGen/benchmark/accuracy/evaluate.py
Normal file
44
FaqGen/benchmark/accuracy/evaluate.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Evaluate the generated FAQs with Ragas metrics, using a TGI LLM as judge.

Reads data/sqv2_context.json (produced by get_context.py) and
data/sqv2_faq.json (produced by post_process_FAQ.py), builds a Ragas test
case, and prints the resulting scores.
"""

import json
import os

from evals.metrics.ragas import RagasMetric
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

# Endpoint of the judge LLM launched by launch_tgi.sh.
llm_endpoint = os.getenv("LLM_ENDPOINT", "http://0.0.0.0:8082")

# Both JSON files map string indices ("0", "1", ...) to text.
# Use context managers so the handles are closed (the original leaked them
# by rebinding `f` without closing).
with open("data/sqv2_context.json", "r") as f:
    sqv2_context = json.load(f)

with open("data/sqv2_faq.json", "r") as f:
    sqv2_faq = json.load(f)

# Prompt template applied to every context; must match generate_FAQ.py's
# request semantics so "question" reflects what the model actually saw.
templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
Do not use any prefix or suffix to the FAQ.
"""

number = 1204  # unique contexts in the Squad_v2 validation split
question = []
answer = []
ground_truth = ["None"] * number  # no reference answers for generated FAQs
contexts = []
for i in range(number):
    inputs = sqv2_context[str(i)]
    inputs_faq = templ.format_map({"text": inputs})
    actual_output = sqv2_faq[str(i)]

    question.append(inputs_faq)
    answer.append(actual_output)
    contexts.append([inputs_faq])

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
metrics_faq = ["answer_relevancy", "faithfulness", "context_utilization", "reference_free_rubrics_score"]
metric = RagasMetric(threshold=0.5, model=llm_endpoint, embeddings=embeddings, metrics=metrics_faq)

test_case = {"question": question, "answer": answer, "ground_truth": ground_truth, "contexts": contexts}

metric.measure(test_case)
print(metric.score)
|
||||
28
FaqGen/benchmark/accuracy/generate_FAQ.py
Normal file
28
FaqGen/benchmark/accuracy/generate_FAQ.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Generate FAQs for every Squad_v2 context via the FaqGen microservice.

Reads data/sqv2_context.json and writes one raw response file per record to
data/result/sqv2_faq_{i}, which post_process_FAQ.py later consolidates.
"""

import json
import os
import time

import requests

# FaqGen microservice endpoint (see the README for how to launch it).
llm_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")

with open("data/sqv2_context.json", "r") as f:
    sqv2_context = json.load(f)

# The original crashed if data/result did not exist; create it up front.
os.makedirs("data/result", exist_ok=True)

start_time = time.time()
headers = {"Content-Type": "application/json"}
for i in range(1204):  # unique contexts in the Squad_v2 validation split
    start_time_tmp = time.time()
    print(i)
    inputs = sqv2_context[str(i)]
    data = {"query": inputs, "max_new_tokens": 128}
    response = requests.post(llm_endpoint, json=data, headers=headers)
    # One file per record so a partial run can be inspected or resumed;
    # the context is written first, then the raw (streamed) response body.
    with open(f"data/result/sqv2_faq_{i}", "w") as out:
        out.write(inputs)
        out.write(str(response.content, encoding="utf-8"))
    print(f"Cost {time.time()-start_time_tmp} seconds")
print(f"\n Finished! \n Totally Cost {time.time()-start_time} seconds\n")
|
||||
17
FaqGen/benchmark/accuracy/get_context.py
Normal file
17
FaqGen/benchmark/accuracy/get_context.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Extract the unique "context" strings from the Squad_v2 validation split.

Writes data/sqv2_context.json mapping indices (serialized as "0", "1", ...)
to each unique context — the input expected by generate_FAQ.py and
evaluate.py.
"""

import json
import os

import pandas as pd

data_path = "./data"
data = pd.read_parquet(os.path.join(data_path, "squad_v2/squad_v2/validation-00000-of-00001.parquet"))
sq_context = list(data["context"].unique())
# Key by position so downstream scripts can look records up as d[str(i)]
# (json.dump serializes the int keys as strings).
sq_context_d = {i: ctx for i, ctx in enumerate(sq_context)}

with open(os.path.join(data_path, "sqv2_context.json"), "w") as outfile:
    json.dump(sq_context_d, outfile)
|
||||
28
FaqGen/benchmark/accuracy/launch_tgi.sh
Normal file
28
FaqGen/benchmark/accuracy/launch_tgi.sh
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Launch a tgi-gaudi container serving the Ragas judge LLM on Gaudi HPUs.
# Requires HUGGING_FACE_HUB_TOKEN to be exported (the model is gated).

max_input_tokens=3072
max_total_tokens=4096
port_number=8082
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
volume="./data"  # model weights are cached here between runs

# NOTE: `docker run` rejects --rm combined with --restart (conflicting
# options); the original passed both and failed to start. --rm is kept to
# match the interactive (-it) usage.
docker run -it --rm \
    --name="tgi_Mixtral" \
    -p "$port_number:80" \
    -v "$volume:/data" \
    --runtime=habana \
    -e HUGGING_FACE_HUB_TOKEN="$HUGGING_FACE_HUB_TOKEN" \
    -e HABANA_VISIBLE_DEVICES=all \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
    --cap-add=sys_nice \
    --ipc=host \
    -e HTTPS_PROXY="$https_proxy" \
    -e HTTP_PROXY="$https_proxy" \
    ghcr.io/huggingface/tgi-gaudi:2.0.1 \
    --model-id "$model_name" \
    --max-input-tokens "$max_input_tokens" \
    --max-total-tokens "$max_total_tokens" \
    --sharded true \
    --num-shard 2
|
||||
27
FaqGen/benchmark/accuracy/post_process_FAQ.py
Normal file
27
FaqGen/benchmark/accuracy/post_process_FAQ.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Consolidate raw microservice responses into data/sqv2_faq.json.

Each file data/result/sqv2_faq_{i} (written by generate_FAQ.py) holds the
context followed by the streamed response. The FAQ text is extracted from
the "LLMChain/final_output" event and stored under the record's string
index; indices that cannot be parsed are reported at the end.
"""

import json

faq_dict = {}
fails = []
for i in range(1204):  # unique contexts in the Squad_v2 validation split
    # Close each file explicitly (the original leaked one handle per record).
    with open(f"data/result/sqv2_faq_{i}", "r") as f:
        data = f.readlines()
    # The final-output event sits 6 lines from the end; [6:] strips the
    # "data: " prefix. NOTE(review): assumes the TGI streaming layout is
    # stable across responses — confirm before reuse.
    result = data[-6][6:]
    if "LLMChain/final_output" not in result:
        print(f"error1: fail for {i}")
        fails.append(i)
        continue
    try:
        result2 = json.loads(result)
        result3 = result2["ops"][0]["value"]["text"]
        faq_dict[str(i)] = result3
    # Narrowed from a bare except, which also swallowed KeyboardInterrupt.
    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
        print(f"error2: fail for {i}")
        fails.append(i)
        continue
with open("data/sqv2_faq.json", "w") as outfile:
    json.dump(faq_dict, outfile)
print("Failure index:")
print(fails)
|
||||
Reference in New Issue
Block a user