move evaluation scripts (#842)

Co-authored-by: root <root@idc708073.jf.intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lkk
2024-09-19 15:59:13 +08:00
committed by GitHub
parent 872e93e4bd
commit f04f061f8c
41 changed files with 478 additions and 7684 deletions

View File

@@ -0,0 +1,78 @@
# FaqGen Evaluation
## Dataset
We evaluate FaqGen performance on the QA dataset [Squad_v2](https://huggingface.co/datasets/rajpurkar/squad_v2): FAQs are generated from the "context" column of the validation split, which contains 1204 unique records.
First download the dataset and put it at "./data".
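One way to fetch the parquet files (a sketch, assuming `huggingface_hub` is installed; any method works as long as the validation parquet ends up at the path `get_context.py` reads):
```python
# Sketch: download the Squad_v2 dataset repo so that
# ./data/squad_v2/squad_v2/validation-00000-of-00001.parquet exists.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="rajpurkar/squad_v2",
    repo_type="dataset",
    local_dir="./data/squad_v2",
)
```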
Extract the unique "context" values, which will be saved to 'data/sqv2_context.json':
```bash
python get_context.py
```
## Generate FAQs
### Launch FaqGen microservice
Please refer to the [FaqGen microservice](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/faq-generation/tgi) to set up a microservice endpoint.
```bash
export FAQ_ENDPOINT="http://${your_ip}:9000/v1/faqgen"
```
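Optionally smoke-test the endpoint with a single request before running the full batch (a minimal sketch; the payload shape is the same one `generate_FAQ.py` sends):
```python
# Sketch: send one query through the FaqGen endpoint and print the raw response.
import os

import requests

endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")
data = {"query": "Deep learning is a subset of machine learning.", "max_new_tokens": 128}
response = requests.post(endpoint, json=data, headers={"Content-Type": "application/json"})
print(response.text)
```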
### Generate FAQs with microservice
Use the microservice endpoint to generate FAQs for the dataset:
```bash
python generate_FAQ.py
```
Post-process the output to extract the generated FAQ text, which will be saved to 'data/sqv2_faq.json':
```bash
python post_process_FAQ.py
```
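The resulting file maps each record index (a string key) to its generated FAQ text. A quick spot-check (hypothetical snippet, not part of the scripts):
```python
# Sketch: inspect one post-processed entry.
import json

with open("data/sqv2_faq.json") as f:
    sqv2_faq = json.load(f)
print(sqv2_faq["0"])  # FAQ text generated for the first context
```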
## Evaluate with Ragas
### Launch TGI service
We use "mistralai/Mixtral-8x7B-Instruct-v0.1" as the judge LLM to evaluate the generated FAQs. First launch an LLM endpoint on Gaudi:
```bash
export HUGGING_FACE_HUB_TOKEN="your_huggingface_token"
bash launch_tgi.sh
```
Set the endpoint variable:
```bash
export LLM_ENDPOINT="http://${ip_address}:8082"
```
Verify the service:
```bash
curl http://${ip_address}:8082/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":128}}' \
-H 'Content-Type: application/json'
```
### Evaluate
Evaluate the generated FAQs with the judge LLM:
```bash
python evaluate.py
```
### Performance Result
Here are the results from our test, for reference:

| answer_relevancy | faithfulness | context_utilization | reference_free_rubrics_score |
| ---------------- | ------------ | ------------------- | ---------------------------- |
| 0.7191           | 0.9681       | 0.8964              | 4.4125                       |

View File (evaluate.py)

@@ -0,0 +1,44 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import json
import os

from evals.metrics.ragas import RagasMetric
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

llm_endpoint = os.getenv("LLM_ENDPOINT", "http://0.0.0.0:8082")

# Load the unique contexts and the FAQs generated for them in the earlier steps.
with open("data/sqv2_context.json", "r") as f:
    sqv2_context = json.load(f)
with open("data/sqv2_faq.json", "r") as f:
    sqv2_faq = json.load(f)

# The same prompt template used at generation time, replayed here so the judge
# LLM scores the answers against the questions the microservice actually saw.
templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
Do not use any prefix or suffix to the FAQ.
"""

number = 1204  # unique contexts in the Squad_v2 validation split

question = []
answer = []
ground_truth = ["None"] * number  # no references; the metrics here are reference-free
contexts = []
for i in range(number):
    inputs = sqv2_context[str(i)]
    inputs_faq = templ.format_map({"text": inputs})
    actual_output = sqv2_faq[str(i)]
    question.append(inputs_faq)
    answer.append(actual_output)
    contexts.append([inputs_faq])

# Embeddings are required by the answer_relevancy metric.
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")

metrics_faq = ["answer_relevancy", "faithfulness", "context_utilization", "reference_free_rubrics_score"]
metric = RagasMetric(threshold=0.5, model=llm_endpoint, embeddings=embeddings, metrics=metrics_faq)

test_case = {"question": question, "answer": answer, "ground_truth": ground_truth, "contexts": contexts}
metric.measure(test_case)
print(metric.score)

View File (generate_FAQ.py)

@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import json
import os
import time

import requests

llm_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")

# Contexts extracted by get_context.py.
with open("data/sqv2_context.json", "r") as f:
    sqv2_context = json.load(f)

# One raw response file is written per record under data/result/.
os.makedirs("data/result", exist_ok=True)

start_time = time.time()
headers = {"Content-Type": "application/json"}
for i in range(1204):
    start_time_tmp = time.time()
    print(i)
    inputs = sqv2_context[str(i)]
    data = {"query": inputs, "max_new_tokens": 128}
    response = requests.post(llm_endpoint, json=data, headers=headers)
    # Store the context followed by the raw streaming response body.
    with open(f"data/result/sqv2_faq_{i}", "w") as out:
        out.write(inputs)
        out.write(response.content.decode("utf-8"))
    print(f"Cost {time.time()-start_time_tmp} seconds")
print(f"\nFinished! Total cost: {time.time()-start_time} seconds\n")

View File (get_context.py)

@@ -0,0 +1,17 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import json
import os

import pandas as pd

data_path = "./data"

# Read the Squad_v2 validation split and keep one copy of each context passage.
data = pd.read_parquet(os.path.join(data_path, "squad_v2/squad_v2/validation-00000-of-00001.parquet"))
sq_context = list(data["context"].unique())

# Key the contexts by index; json.dump writes the keys as strings ("0", "1", ...),
# which is how the downstream scripts look them up.
sq_context_d = {i: context for i, context in enumerate(sq_context)}

with open(os.path.join(data_path, "sqv2_context.json"), "w") as outfile:
    json.dump(sq_context_d, outfile)

View File (launch_tgi.sh)

@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
max_input_tokens=3072
max_total_tokens=4096
port_number=8082
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
volume="./data"

# Serve Mixtral with TGI on Gaudi, sharded across 2 HPUs.
# Note: Docker does not allow --restart together with --rm, so the container
# is not configured to auto-restart.
docker run -it --rm \
    --name="tgi_Mixtral" \
    -p $port_number:80 \
    -v $volume:/data \
    --runtime=habana \
    -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
    -e HABANA_VISIBLE_DEVICES=all \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
    --cap-add=sys_nice \
    --ipc=host \
    -e HTTPS_PROXY=$https_proxy \
    -e HTTP_PROXY=$http_proxy \
    ghcr.io/huggingface/tgi-gaudi:2.0.1 \
    --model-id $model_name \
    --max-input-tokens $max_input_tokens \
    --max-total-tokens $max_total_tokens \
    --sharded true \
    --num-shard 2

View File (post_process_FAQ.py)

@@ -0,0 +1,27 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import json

faq_dict = {}
fails = []
for i in range(1204):
    with open(f"data/result/sqv2_faq_{i}", "r") as f:
        data = f.readlines()
    # The final streaming event sits on the 6th line from the end of the raw
    # response; [6:] strips its leading "data: " SSE prefix.
    result = data[-6][6:]
    if "LLMChain/final_output" not in result:
        print(f"error1: fail for {i}")
        fails.append(i)
        continue
    try:
        result2 = json.loads(result)
        result3 = result2["ops"][0]["value"]["text"]
        faq_dict[str(i)] = result3
    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
        print(f"error2: fail for {i}")
        fails.append(i)
        continue

with open("data/sqv2_faq.json", "w") as outfile:
    json.dump(faq_dict, outfile)

print("Failure index:")
print(fails)