fix bugs in DocIndexRetriever (#1770)

Signed-off-by: minmin-intel <minmin.hou@intel.com>
This commit is contained in:
minmin-intel
2025-04-09 18:45:46 -07:00
committed by GitHub
parent 00d7a65dd8
commit 411bb28f41
12 changed files with 152 additions and 111 deletions

View File

@@ -80,9 +80,22 @@ Example usage:
```python
url = "http://{host_ip}:{port}/v1/retrievaltool".format(host_ip=host_ip, port=port)
payload = {
"messages": query,
"messages": query, # must be a string, this is a required field
"k": 5, # retriever top k
"top_n": 2, # reranker top n
}
response = requests.post(url, json=payload)
```
**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below.
1. retriever
* search_type: str = "similarity"
* k: int = 4
* distance_threshold: Optional[float] = None
* fetch_k: int = 20
* lambda_mult: float = 0.5
* score_threshold: float = 0.2
2. reranker
* top_n: int = 1

View File

@@ -97,9 +97,6 @@ Retrieval from KnowledgeBase
curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"messages": "Explain the OPEA project?"
}'
# expected output
{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"}
```
**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below.
@@ -128,7 +125,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati
# embedding microservice
curl http://${host_ip}:6000/v1/embeddings \
-X POST \
-d '{"text":"Explain the OPEA project"}' \
-d '{"messages":"Explain the OPEA project"}' \
-H 'Content-Type: application/json' > query
docker container logs embedding-server

View File

@@ -13,10 +13,11 @@ services:
dataprep-redis-service:
image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
container_name: dataprep-redis-server
# volumes:
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
depends_on:
- redis-vector-db
redis-vector-db:
condition: service_started
tei-embedding-service:
condition: service_healthy
ports:
- "6007:5000"
- "6008:6008"
@@ -28,7 +29,7 @@ services:
REDIS_URL: ${REDIS_URL}
REDIS_HOST: ${REDIS_HOST}
INDEX_NAME: ${INDEX_NAME}
TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
LOGFLAG: ${LOGFLAG}
tei-embedding-service:
@@ -54,8 +55,6 @@ services:
embedding:
image: ${REGISTRY:-opea}/embedding:${TAG:-latest}
container_name: embedding-server
# volumes:
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/comps
ports:
- "6000:6000"
ipc: host
@@ -114,8 +113,6 @@ services:
reranking:
image: ${REGISTRY:-opea}/reranking:${TAG:-latest}
container_name: reranking-tei-xeon-server
# volumes:
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
depends_on:
tei-reranking-service:
condition: service_healthy

View File

@@ -87,9 +87,6 @@ Retrieval from KnowledgeBase
curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"messages": "Explain the OPEA project?"
}'
# expected output
{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"}
```
**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below.
@@ -118,7 +115,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati
# embedding microservice
curl http://${host_ip}:6000/v1/embeddings \
-X POST \
-d '{"text":"Explain the OPEA project"}' \
-d '{"messages":"Explain the OPEA project"}' \
-H 'Content-Type: application/json' > query
docker container logs embedding-server

View File

@@ -15,8 +15,10 @@ services:
image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
container_name: dataprep-redis-server
depends_on:
- redis-vector-db
- tei-embedding-service
redis-vector-db:
condition: service_started
tei-embedding-service:
condition: service_healthy
ports:
- "6007:5000"
environment:
@@ -25,7 +27,7 @@ services:
https_proxy: ${https_proxy}
REDIS_URL: ${REDIS_URL}
INDEX_NAME: ${INDEX_NAME}
TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
tei-embedding-service:
image: ghcr.io/huggingface/tei-gaudi:1.5.0
@@ -87,6 +89,8 @@ services:
INDEX_NAME: ${INDEX_NAME}
LOGFLAG: ${LOGFLAG}
RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_REDIS"
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
restart: unless-stopped
tei-reranking-service:
image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.6

View File

@@ -8,9 +8,8 @@ from typing import Union
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from comps.cores.proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from fastapi import Request
from fastapi.responses import StreamingResponse
MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889)
EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
@@ -22,41 +21,75 @@ RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
print(f"Inputs to {cur_node}: {inputs}")
print(f"*** Inputs to {cur_node}:\n{inputs}")
print("--" * 50)
for key, value in kwargs.items():
print(f"{key}: {value}")
if self.services[cur_node].service_type == ServiceType.EMBEDDING:
inputs["input"] = inputs["text"]
del inputs["text"]
elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
# input is EmbedDoc
"""Class EmbedDoc(BaseDoc):
text: Union[str, List[str]]
embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]]
search_type: str = "similarity"
k: int = 4
distance_threshold: Optional[float] = None
fetch_k: int = 20
lambda_mult: float = 0.5
score_threshold: float = 0.2
constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
index_name: Optional[str] = None
"""
# prepare the retriever params
retriever_parameters = kwargs.get("retriever_parameters", None)
if retriever_parameters:
inputs.update(retriever_parameters.dict())
elif self.services[cur_node].service_type == ServiceType.RERANK:
# input is SearchedDoc
"""Class SearchedDoc(BaseDoc):
retrieved_docs: DocList[TextDoc]
initial_query: str
top_n: int = 1
"""
# prepare the reranker params
reranker_parameters = kwargs.get("reranker_parameters", None)
if reranker_parameters:
inputs.update(reranker_parameters.dict())
print(f"*** Formatted Inputs to {cur_node}:\n{inputs}")
print("--" * 50)
return inputs
def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
next_data = {}
print(f"*** Direct Outputs from {cur_node}:\n{data}")
print("--" * 50)
if self.services[cur_node].service_type == ServiceType.EMBEDDING:
# turn into chat completion request
# next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
print("Assembing output from Embedding for next node...")
print("Inputs to Embedding: ", inputs)
print("Keyword arguments: ")
for key, value in kwargs.items():
print(f"{key}: {value}")
next_data = {
"input": inputs["input"],
"messages": inputs["input"],
"embedding": [item["embedding"] for item in data["data"]],
"k": kwargs["k"] if "k" in kwargs else 4,
"search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity",
"distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None,
"fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20,
"lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5,
"score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2,
"top_n": kwargs["top_n"] if "top_n" in kwargs else 1,
}
print("Output from Embedding for next node:\n", next_data)
# direct output from Embedding microservice is EmbeddingResponse
"""
class EmbeddingResponse(BaseModel):
object: str = "list"
model: Optional[str] = None
data: List[EmbeddingResponseData]
usage: Optional[UsageInfo] = None
class EmbeddingResponseData(BaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]
"""
# turn it into EmbedDoc
assert isinstance(data["data"], list)
next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]} # EmbedDoc
else:
next_data = data
print(f"*** Formatted Output from {cur_node} for next node:\n", next_data)
print("--" * 50)
return next_data
@@ -100,54 +133,41 @@ class RetrievalToolService:
self.megaservice.flow_to(retriever, rerank)
async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
chat_request = None
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query, chat_request
data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query, chat_request = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
if chat_request is None:
raise ValueError(f"Unknown request type: {data}")
chat_request = ChatCompletionRequest.parse_obj(data)
if isinstance(chat_request, ChatCompletionRequest):
initial_inputs = {
"messages": query,
"input": query, # has to be input due to embedding expects either input or text
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
"k": chat_request.k if chat_request.k else 4,
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
"top_n": chat_request.top_n if chat_request.top_n else 1,
}
prompt = chat_request.messages
kwargs = {
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
"k": chat_request.k if chat_request.k else 4,
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
"top_n": chat_request.top_n if chat_request.top_n else 1,
}
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs,
**kwargs,
)
else:
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query})
# dummy llm params
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
model=chat_request.model if chat_request.model else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt},
llm_parameters=parameters,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]

View File

@@ -2,16 +2,17 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any
import requests
def search_knowledge_base(query: str) -> str:
def search_knowledge_base(query: str, args: Any) -> str:
"""Search the knowledge base for a specific query."""
url = os.environ.get("RETRIEVAL_TOOL_URL")
url = os.environ.get("RETRIEVAL_TOOL_URL", "http://localhost:8889/v1/retrievaltool")
print(url)
proxies = {"http": ""}
payload = {"messages": query, "k": 5, "top_n": 2}
payload = {"messages": query, "k": args.k, "top_n": args.top_n}
response = requests.post(url, json=payload, proxies=proxies)
print(response)
if "documents" in response.json():
@@ -33,6 +34,16 @@ def search_knowledge_base(query: str) -> str:
if __name__ == "__main__":
resp = search_knowledge_base("What is OPEA?")
# resp = search_knowledge_base("Thriller")
import argparse
parser = argparse.ArgumentParser(description="Test the knowledge base search.")
parser.add_argument("--k", type=int, default=5, help="retriever top k")
parser.add_argument("--top_n", type=int, default=2, help="reranker top n")
args = parser.parse_args()
resp = search_knowledge_base("What is OPEA?", args)
print(resp)
if not resp.startswith("Error"):
print("Test successful!")

View File

@@ -88,10 +88,10 @@ function validate_megaservice() {
fi
# Curl the Mega Service
echo "================Testing retriever service: Text Request ================"
echo "================Testing retriever service ================"
cd $WORKPATH/tests
local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"text": "Explain the OPEA project?"
"messages": "Explain the OPEA project?"
}')
local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-gaudi")

View File

@@ -87,10 +87,10 @@ function validate_megaservice() {
fi
# Curl the Mega Service
echo "================Testing retriever service: Text Request ================"
echo "================Testing retriever service ================"
cd $WORKPATH/tests
local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"text": "Explain the OPEA project?"
"messages": "Explain the OPEA project?"
}')
local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-xeon")

View File

@@ -38,7 +38,6 @@ function start_services() {
export RERANK_MODEL_ID="BAAI/bge-reranker-base"
export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:8090"
export TEI_RERANKING_ENDPOINT="http://${ip_address}:8808"
export TGI_LLM_ENDPOINT="http://${ip_address}:8008"
export REDIS_URL="redis://${ip_address}:6379"
export INDEX_NAME="rag-redis"
export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
@@ -46,14 +45,13 @@ function start_services() {
export EMBEDDING_SERVICE_HOST_IP=${ip_address}
export RETRIEVER_SERVICE_HOST_IP=${ip_address}
export RERANK_SERVICE_HOST_IP=${ip_address}
export LLM_SERVICE_HOST_IP=${ip_address}
export host_ip=${ip_address}
export RERANK_TYPE="tei"
export LOGFLAG=true
# Start Docker Containers
docker compose up -d
sleep 30
sleep 1m
echo "Docker services started!"
}
@@ -86,11 +84,13 @@ function validate_megaservice() {
fi
# Curl the Mega Service
echo "==============Testing retriever service: Text Request================="
local CONTENT=$(curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"text": "Explain the OPEA project?"
echo "==============Testing retriever service================="
local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"messages": "Explain the OPEA project?"
}')
local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-gaudi")
echo "$EXIT_CODE"
local EXIT_CODE="${EXIT_CODE:0-1}"
echo "return value is $EXIT_CODE"

View File

@@ -53,7 +53,7 @@ function start_services() {
# Start Docker Containers
docker compose up -d
sleep 5m
sleep 1m
echo "Docker services started!"
}
@@ -86,10 +86,11 @@ function validate_megaservice() {
fi
# Curl the Mega Service
echo "================Testing retriever service: Text Request ================"
echo "================Testing retriever service ================"
cd $WORKPATH/tests
local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"text": "Explain the OPEA project?"
"messages": "Explain the OPEA project?"
}')
local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-xeon")
@@ -128,6 +129,7 @@ function main() {
if [[ "$IMAGE_REPO" == "opea" ]]; then build_docker_images; fi
echo "Dump current docker ps"
docker ps
start_time=$(date +%s)
start_services
end_time=$(date +%s)

View File

@@ -80,10 +80,10 @@ function validate_megaservice() {
fi
# Curl the Mega Service
echo "================Testing retriever service: Text Request ================"
echo "================Testing retriever service ================"
cd $WORKPATH/tests
local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"text": "Explain the OPEA project?"
"messages": "Explain the OPEA project?"
}')
local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-xeon")