Update AgentQnA and DocIndexRetriever (#1564)

Signed-off-by: minmin-intel <minmin.hou@intel.com>
Author: minmin-intel
Date: 2025-02-21 17:51:26 -08:00
Committed by: GitHub
Parent: caec354324
Commit: a7eced4161

11 changed files with 172 additions and 120 deletions


@@ -13,6 +13,8 @@ services:
   dataprep-redis-service:
     image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
     container_name: dataprep-redis-server
+    # volumes:
+    #   - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
     depends_on:
       - redis-vector-db
     ports:
@@ -52,6 +54,8 @@ services:
   embedding:
     image: ${REGISTRY:-opea}/embedding:${TAG:-latest}
     container_name: embedding-server
+    # volumes:
+    #   - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/comps
     ports:
       - "6000:6000"
     ipc: host
@@ -110,6 +114,8 @@ services:
   reranking:
     image: ${REGISTRY:-opea}/reranking:${TAG:-latest}
     container_name: reranking-tei-xeon-server
+    # volumes:
+    #   - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
     depends_on:
       tei-reranking-service:
         condition: service_healthy
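
Note: the commented-out volumes entries added above are a development convenience. Uncommenting them bind-mounts a local GenAIComps checkout over the comps package baked into each image, so host-side edits take effect on container restart without rebuilding images. A minimal pre-flight check for that workflow, assuming only that WORKDIR is exported as the compose file expects (this helper is an illustration, not part of the commit):

import os

# WORKDIR must be exported in the shell that runs `docker compose up`;
# otherwise the bind-mount source resolves relative to an empty string.
workdir = os.environ.get("WORKDIR", "")
src = os.path.join(workdir, "GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps")
if not os.path.isdir(src):
    raise SystemExit(f"Dev mount source not found: {src!r}; export WORKDIR first.")
print(f"OK: {src} will shadow the in-image comps package.")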


@@ -22,16 +22,38 @@ RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
 def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
     if self.services[cur_node].service_type == ServiceType.EMBEDDING:
         inputs["input"] = inputs["text"]
         del inputs["text"]
+    print(f"Inputs to {cur_node}: {inputs}")
+    for key, value in kwargs.items():
+        print(f"{key}: {value}")
     return inputs


 def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
     next_data = {}
     if self.services[cur_node].service_type == ServiceType.EMBEDDING:
-        next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
+        # turn into chat completion request
+        # next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
+        print("Assembling output from Embedding for next node...")
+        print("Inputs to Embedding: ", inputs)
+        print("Keyword arguments: ")
+        for key, value in kwargs.items():
+            print(f"{key}: {value}")
+        next_data = {
+            "input": inputs["input"],
+            "messages": inputs["input"],
+            "embedding": data,  # [item["embedding"] for item in data["data"]],
+            "k": kwargs["k"] if "k" in kwargs else 4,
+            "search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity",
+            "distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None,
+            "fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20,
+            "lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5,
+            "score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2,
+            "top_n": kwargs["top_n"] if "top_n" in kwargs else 1,
+        }
+        print("Output from Embedding for next node:\n", next_data)
     else:
         next_data = data
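
The repeated `kwargs["x"] if "x" in kwargs else default` pattern in the new align_outputs is equivalent to dict.get with a default. A compact sketch of the same payload assembly, illustrative only (the function and constant names here are mine, not the committed code):

RETRIEVER_DEFAULTS = {
    "k": 4,
    "search_type": "similarity",
    "distance_threshold": None,
    "fetch_k": 20,
    "lambda_mult": 0.5,
    "score_threshold": 0.2,
    "top_n": 1,
}


def build_retriever_payload(inputs: dict, data, **kwargs) -> dict:
    # Same shape as the new EMBEDDING branch: query text plus raw embedding output,
    # with retriever/reranker tunables falling back to defaults when omitted.
    payload = {"input": inputs["input"], "messages": inputs["input"], "embedding": data}
    payload.update({key: kwargs.get(key, default) for key, default in RETRIEVER_DEFAULTS.items()})
    return payload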
@@ -99,18 +121,6 @@ class RetrievalToolService:
             raise ValueError(f"Unknown request type: {data}")

         if isinstance(chat_request, ChatCompletionRequest):
-            retriever_parameters = RetrieverParms(
-                search_type=chat_request.search_type if chat_request.search_type else "similarity",
-                k=chat_request.k if chat_request.k else 4,
-                distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
-                fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
-                lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
-                score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
-            )
-            reranker_parameters = RerankerParms(
-                top_n=chat_request.top_n if chat_request.top_n else 1,
-            )
             initial_inputs = {
                 "messages": query,
                 "input": query,  # has to be "input" because the embedding service expects either "input" or "text"
@@ -123,13 +133,21 @@ class RetrievalToolService:
                 "top_n": chat_request.top_n if chat_request.top_n else 1,
             }
+            kwargs = {
+                "search_type": chat_request.search_type if chat_request.search_type else "similarity",
+                "k": chat_request.k if chat_request.k else 4,
+                "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
+                "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
+                "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+                "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
+                "top_n": chat_request.top_n if chat_request.top_n else 1,
+            }
             result_dict, runtime_graph = await self.megaservice.schedule(
                 initial_inputs=initial_inputs,
-                retriever_parameters=retriever_parameters,
-                reranker_parameters=reranker_parameters,
+                **kwargs,
             )
         else:
-            result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
+            result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query})
         last_node = runtime_graph.all_leaves()[-1]
         response = result_dict[last_node]
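
With this change, retriever and reranker tunables travel to megaservice.schedule as plain keyword arguments instead of the removed RetrieverParms/RerankerParms objects, and align_outputs forwards them to the retriever node. A client-side sketch of exercising those knobs (the URL is a placeholder assumption, not taken from this commit):

import requests

url = "http://localhost:8889/v1/retrievaltool"  # assumed deployment endpoint
payload = {
    "messages": "What is OPEA?",
    "k": 5,      # retriever: candidates to fetch
    "top_n": 2,  # reranker: documents to keep
}
print(requests.post(url, json=payload, timeout=60).json())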


@@ -0,0 +1,38 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import requests
+
+
+def search_knowledge_base(query: str) -> str:
+    """Search the knowledge base for a specific query."""
+    url = os.environ.get("RETRIEVAL_TOOL_URL")
+    print(url)
+    proxies = {"http": ""}
+    payload = {"messages": query, "k": 5, "top_n": 2}
+    response = requests.post(url, json=payload, proxies=proxies)
+    print(response)
+    if "documents" in response.json():
+        docs = response.json()["documents"]
+        context = ""
+        for i, doc in enumerate(docs):
+            context += f"Doc[{i+1}]:\n{doc}\n"
+        return context
+    elif "text" in response.json():
+        return response.json()["text"]
+    elif "reranked_docs" in response.json():
+        docs = response.json()["reranked_docs"]
+        context = ""
+        for i, doc in enumerate(docs):
+            context += f"Doc[{i+1}]:\n{doc}\n"
+        return context
+    else:
+        return "Error parsing response from the knowledge base."
+
+
+if __name__ == "__main__":
+    resp = search_knowledge_base("What is OPEA?")
+    # resp = search_knowledge_base("Thriller")
+    print(resp)
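
For reference, the three response shapes the parser above accepts, one per branch; the bodies are hypothetical illustrations, not captured service output:

# One example body per branch of search_knowledge_base:
retriever_style = {"documents": ["doc one ...", "doc two ..."]}  # joined into Doc[i] blocks
plain_text = {"text": "a single answer string"}                  # returned verbatim
reranked_style = {"reranked_docs": ["best doc ..."]}             # joined like "documents"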