Update AgentQnA and DocIndexRetriever (#1564)
Signed-off-by: minmin-intel <minmin.hou@intel.com>
This commit is contained in:
@@ -13,6 +13,8 @@ services:
|
||||
dataprep-redis-service:
|
||||
image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
|
||||
container_name: dataprep-redis-server
|
||||
# volumes:
|
||||
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
|
||||
depends_on:
|
||||
- redis-vector-db
|
||||
ports:
|
||||
@@ -52,6 +54,8 @@ services:
|
||||
embedding:
|
||||
image: ${REGISTRY:-opea}/embedding:${TAG:-latest}
|
||||
container_name: embedding-server
|
||||
# volumes:
|
||||
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/comps
|
||||
ports:
|
||||
- "6000:6000"
|
||||
ipc: host
|
||||
@@ -110,6 +114,8 @@ services:
|
||||
reranking:
|
||||
image: ${REGISTRY:-opea}/reranking:${TAG:-latest}
|
||||
container_name: reranking-tei-xeon-server
|
||||
# volumes:
|
||||
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
|
||||
depends_on:
|
||||
tei-reranking-service:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -22,16 +22,38 @@ RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
|
||||
|
||||
|
||||
def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
|
||||
if self.services[cur_node].service_type == ServiceType.EMBEDDING:
|
||||
inputs["input"] = inputs["text"]
|
||||
del inputs["text"]
|
||||
print(f"Inputs to {cur_node}: {inputs}")
|
||||
for key, value in kwargs.items():
|
||||
print(f"{key}: {value}")
|
||||
return inputs
|
||||
|
||||
|
||||
def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
|
||||
next_data = {}
|
||||
if self.services[cur_node].service_type == ServiceType.EMBEDDING:
|
||||
next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
|
||||
# turn into chat completion request
|
||||
# next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
|
||||
print("Assembing output from Embedding for next node...")
|
||||
print("Inputs to Embedding: ", inputs)
|
||||
print("Keyword arguments: ")
|
||||
for key, value in kwargs.items():
|
||||
print(f"{key}: {value}")
|
||||
|
||||
next_data = {
|
||||
"input": inputs["input"],
|
||||
"messages": inputs["input"],
|
||||
"embedding": data, # [item["embedding"] for item in data["data"]],
|
||||
"k": kwargs["k"] if "k" in kwargs else 4,
|
||||
"search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity",
|
||||
"distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None,
|
||||
"fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20,
|
||||
"lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5,
|
||||
"score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2,
|
||||
"top_n": kwargs["top_n"] if "top_n" in kwargs else 1,
|
||||
}
|
||||
|
||||
print("Output from Embedding for next node:\n", next_data)
|
||||
|
||||
else:
|
||||
next_data = data
|
||||
|
||||
@@ -99,18 +121,6 @@ class RetrievalToolService:
|
||||
raise ValueError(f"Unknown request type: {data}")
|
||||
|
||||
if isinstance(chat_request, ChatCompletionRequest):
|
||||
retriever_parameters = RetrieverParms(
|
||||
search_type=chat_request.search_type if chat_request.search_type else "similarity",
|
||||
k=chat_request.k if chat_request.k else 4,
|
||||
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
|
||||
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
|
||||
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
|
||||
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
|
||||
)
|
||||
reranker_parameters = RerankerParms(
|
||||
top_n=chat_request.top_n if chat_request.top_n else 1,
|
||||
)
|
||||
|
||||
initial_inputs = {
|
||||
"messages": query,
|
||||
"input": query, # has to be input due to embedding expects either input or text
|
||||
@@ -123,13 +133,21 @@ class RetrievalToolService:
|
||||
"top_n": chat_request.top_n if chat_request.top_n else 1,
|
||||
}
|
||||
|
||||
kwargs = {
|
||||
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
|
||||
"k": chat_request.k if chat_request.k else 4,
|
||||
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
|
||||
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
|
||||
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
|
||||
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
|
||||
"top_n": chat_request.top_n if chat_request.top_n else 1,
|
||||
}
|
||||
result_dict, runtime_graph = await self.megaservice.schedule(
|
||||
initial_inputs=initial_inputs,
|
||||
retriever_parameters=retriever_parameters,
|
||||
reranker_parameters=reranker_parameters,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
|
||||
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query})
|
||||
|
||||
last_node = runtime_graph.all_leaves()[-1]
|
||||
response = result_dict[last_node]
|
||||
|
||||
38
DocIndexRetriever/tests/test.py
Normal file
38
DocIndexRetriever/tests/test.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def search_knowledge_base(query: str) -> str:
|
||||
"""Search the knowledge base for a specific query."""
|
||||
url = os.environ.get("RETRIEVAL_TOOL_URL")
|
||||
print(url)
|
||||
proxies = {"http": ""}
|
||||
payload = {"messages": query, "k": 5, "top_n": 2}
|
||||
response = requests.post(url, json=payload, proxies=proxies)
|
||||
print(response)
|
||||
if "documents" in response.json():
|
||||
docs = response.json()["documents"]
|
||||
context = ""
|
||||
for i, doc in enumerate(docs):
|
||||
context += f"Doc[{i+1}]:\n{doc}\n"
|
||||
return context
|
||||
elif "text" in response.json():
|
||||
return response.json()["text"]
|
||||
elif "reranked_docs" in response.json():
|
||||
docs = response.json()["reranked_docs"]
|
||||
context = ""
|
||||
for i, doc in enumerate(docs):
|
||||
context += f"Doc[{i+1}]:\n{doc}\n"
|
||||
return context
|
||||
else:
|
||||
return "Error parsing response from the knowledge base."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
resp = search_knowledge_base("What is OPEA?")
|
||||
# resp = search_knowledge_base("Thriller")
|
||||
print(resp)
|
||||
Reference in New Issue
Block a user