Minor fix DocIndexRetriever test (#1266)

This commit is contained in:
Sihan Chen
2024-12-19 12:12:33 +08:00
committed by GitHub
parent 7d9b34cf5e
commit 3b9e55cb8e
3 changed files with 6 additions and 86 deletions

View File

@@ -1,70 +0,0 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import requests
def search_knowledge_base(query: str, url: str, request_type: str) -> str:
"""Search the knowledge base for a specific query."""
print(url)
proxies = {"http": ""}
if request_type == "chat_completion":
print("Sending chat completion request")
payload = {
"messages": query,
"k": 5,
"top_n": 2,
}
else:
print("Sending textdoc request")
payload = {
"text": query,
}
response = requests.post(url, json=payload, proxies=proxies)
print(response)
print(response.json().keys())
if "documents" in response.json():
docs = response.json()["documents"]
context = ""
for i, doc in enumerate(docs):
if i == 0:
context = str(i) + ": " + doc
else:
context += "\n" + str(i) + ": " + doc
return context
elif "text" in response.json():
return response.json()["text"]
elif "reranked_docs" in response.json():
docs = response.json()["reranked_docs"]
context = ""
for i, doc in enumerate(docs):
if i == 0:
context = doc["text"]
else:
context += "\n" + doc["text"]
return context
else:
return "Error parsing response from the knowledge base."
def main():
parser = argparse.ArgumentParser(description="Index data")
parser.add_argument("--host_ip", type=str, default="localhost", help="Host IP")
parser.add_argument("--port", type=int, default=8889, help="Port")
parser.add_argument("--request_type", type=str, default="chat_completion", help="Test type")
args = parser.parse_args()
print(args)
host_ip = args.host_ip
port = args.port
url = "http://{host_ip}:{port}/v1/retrievaltool".format(host_ip=host_ip, port=port)
response = search_knowledge_base("OPEA", url, request_type=args.request_type)
print(response)
if __name__ == "__main__":
main()