Signed-off-by: minmin-intel <minmin.hou@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit c5177c5e2f)
71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
# Copyright (C) 2024 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import argparse
|
|
|
|
import requests
|
|
|
|
|
|
def search_knowledge_base(query: str, url: str, request_type: str) -> str:
|
|
"""Search the knowledge base for a specific query."""
|
|
print(url)
|
|
proxies = {"http": ""}
|
|
if request_type == "chat_completion":
|
|
print("Sending chat completion request")
|
|
payload = {
|
|
"messages": query,
|
|
"k": 5,
|
|
"top_n": 2,
|
|
}
|
|
else:
|
|
print("Sending textdoc request")
|
|
payload = {
|
|
"text": query,
|
|
}
|
|
response = requests.post(url, json=payload, proxies=proxies)
|
|
print(response)
|
|
print(response.json().keys())
|
|
if "documents" in response.json():
|
|
docs = response.json()["documents"]
|
|
context = ""
|
|
for i, doc in enumerate(docs):
|
|
if i == 0:
|
|
context = str(i) + ": " + doc
|
|
else:
|
|
context += "\n" + str(i) + ": " + doc
|
|
return context
|
|
elif "text" in response.json():
|
|
return response.json()["text"]
|
|
elif "reranked_docs" in response.json():
|
|
docs = response.json()["reranked_docs"]
|
|
context = ""
|
|
for i, doc in enumerate(docs):
|
|
if i == 0:
|
|
context = doc["text"]
|
|
else:
|
|
context += "\n" + doc["text"]
|
|
return context
|
|
else:
|
|
return "Error parsing response from the knowledge base."
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Index data")
|
|
parser.add_argument("--host_ip", type=str, default="localhost", help="Host IP")
|
|
parser.add_argument("--port", type=int, default=8889, help="Port")
|
|
parser.add_argument("--request_type", type=str, default="chat_completion", help="Test type")
|
|
args = parser.parse_args()
|
|
print(args)
|
|
|
|
host_ip = args.host_ip
|
|
port = args.port
|
|
url = "http://{host_ip}:{port}/v1/retrievaltool".format(host_ip=host_ip, port=port)
|
|
|
|
response = search_knowledge_base("OPEA", url, request_type=args.request_type)
|
|
|
|
print(response)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|