Refactor reranking (#1113)
Signed-off-by: WenjiaoYue <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>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ZePan110 <ze.pan@intel.com>
@@ -244,7 +244,7 @@ curl http://${your_ip}:8015/v1/finetune/list_checkpoints -X POST -H "Content-Typ
### 3.4 Leverage fine-tuned model

-After fine-tuning job is done, fine-tuned model can be chosen from listed checkpoints, then the fine-tuned model can be used in other microservices. For example, fine-tuned reranking model can be used in [reranks](../reranks/src/README.md) microservice by assign its path to the environment variable `RERANK_MODEL_ID`, fine-tuned embedding model can be used in [embeddings](../embeddings/src/README.md) microservice by assign its path to the environment variable `model`, LLMs after instruction tuning can be used in [llms](../llms/src/text-generation/README.md) microservice by assign its path to the environment variable `your_hf_llm_model`.
+After fine-tuning job is done, fine-tuned model can be chosen from listed checkpoints, then the fine-tuned model can be used in other microservices. For example, fine-tuned reranking model can be used in [rerankings](../rerankings/src/README.md) microservice by assign its path to the environment variable `RERANK_MODEL_ID`, fine-tuned embedding model can be used in [embeddings](../embeddings/src/README.md) microservice by assign its path to the environment variable `model`, LLMs after instruction tuning can be used in [llms](../llms/src/text-generation/README.md) microservice by assign its path to the environment variable `your_hf_llm_model`.

## 🚀4. Descriptions for Finetuning parameters
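As an illustration of the convention in the README paragraph above, each microservice picks up its model location from the named environment variable. The sketch below is only a placeholder example (the checkpoint paths are invented, not taken from this diff):

```python
# Minimal sketch (placeholder paths): point each microservice's environment
# variable at the fine-tuned checkpoint before starting the service.
import os

os.environ.setdefault("RERANK_MODEL_ID", "/checkpoints/finetuned-reranker")   # rerankings
os.environ.setdefault("model", "/checkpoints/finetuned-embedding")            # embeddings
os.environ.setdefault("your_hf_llm_model", "/checkpoints/finetuned-llm")      # llms

for var in ("RERANK_MODEL_ID", "model", "your_hf_llm_model"):
    print(var, "->", os.environ[var])
```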
@@ -4238,7 +4238,7 @@ def _ranking_fast(
    alpha: float,
    beam_width: int,
) -> torch.FloatTensor:
-    """Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
+    """Rerankings the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
    in the paper "A Contrastive Framework for Neural Text Generation".

    Returns the index of the best candidate for each
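The docstring touched above describes contrastive-search reranking: candidate tokens are penalized by their maximum cosine similarity to the previous tokens. The block below is a rough, self-contained sketch of that scoring with assumed tensor shapes; it is an illustration, not the transformers implementation:

```python
# Rough sketch of degeneration-penalty reranking (contrastive search).
# Assumed shapes: context_hidden [batch*beam, ctx_len, hidden],
# next_hidden [batch*beam, 1, hidden], next_top_k_probs [batch, beam].
import torch

def rank_by_degeneration_penalty(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    # cosine similarity of each candidate token against every previous token
    norm_ctx = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine = torch.matmul(norm_ctx, norm_next.transpose(1, 2)).squeeze(-1)  # [batch*beam, ctx_len]
    degeneration_penalty, _ = cosine.max(dim=-1)                            # worst-case similarity
    # trade model confidence off against repetitiveness
    score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
    _, best_idx = score.view(-1, beam_width).max(dim=-1)                    # best candidate per sequence
    return best_idx
```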
@@ -29,6 +29,7 @@ services:
      http_proxy: ${http_proxy}
      https_proxy: ${https_proxy}
      TEI_RERANKING_ENDPOINT: ${TEI_RERANKING_ENDPOINT}
+      RERANK_COMPONENT_NAME: "OPEA_TEI_RERANKING"
      HF_TOKEN: ${HF_TOKEN}
    depends_on:
      tei_reranking_service:
@@ -0,0 +1,22 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

services:
  reranking:
    image: opea/reranking:latest
    container_name: reranking-videoqna-server
    ports:
      - "8000:8000"
    ipc: host
    environment:
      no_proxy: ${no_proxy}
      http_proxy: ${http_proxy}
      https_proxy: ${https_proxy}
      CHUNK_DURATION: ${CHUNK_DURATION}
      FILE_SERVER_ENDPOINT: ${FILE_SERVER_ENDPOINT}
      RERANK_COMPONENT_NAME: "OPEA_VIDEO_RERANKING"
    restart: unless-stopped

networks:
  default:
    driver: bridge
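For context, once this compose service is running on port 8000 a client call might look roughly like the sketch below. The `/v1/reranking` route and the payload fields are assumptions based on the microservice code elsewhere in this PR, not something defined in this compose file:

```python
# Hypothetical client call against the videoqna reranking container above.
# Route and payload fields are assumptions for illustration only.
import requests

payload = {
    "initial_query": "Which clip shows the touchdown?",
    "retrieved_docs": [],
    "top_n": 1,
    "metadata": [{"video": "game_interval_3.mp4", "timestamp": 90}],
}
resp = requests.post("http://localhost:8000/v1/reranking", json=payload, timeout=30)
print(resp.status_code, resp.json())
```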
comps/rerankings/src/Dockerfile (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM python:3.11-slim

ENV LANG=C.UTF-8

ARG ARCH="cpu"
ARG SERVICE="all"

RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
    git \
    libgl1-mesa-glx \
    libjemalloc-dev

RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

USER user

COPY comps /home/user/comps

RUN if [ ${ARCH} = "cpu" ]; then \
        pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu; \
    fi && \
    if [ ${SERVICE} = "videoqna" ]; then \
        pip install --no-cache-dir --upgrade pip setuptools && \
        pip install --no-cache-dir -r /home/user/comps/rerankings/src/requirements_videoqna.txt; \
    elif [ ${SERVICE} = "all" ]; then \
        git clone https://github.com/IntelLabs/fastRAG.git /home/user/fastRAG && \
        cd /home/user/fastRAG && \
        pip install --no-cache-dir --upgrade pip && \
        pip install --no-cache-dir . && \
        pip install --no-cache-dir .[intel] && \
        pip install --no-cache-dir -r /home/user/comps/rerankings/src/requirements_videoqna.txt; \
    fi && \
    pip install --no-cache-dir --upgrade pip setuptools && \
    pip install --no-cache-dir -r /home/user/comps/rerankings/src/requirements.txt;

ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/rerankings/src

ENTRYPOINT ["python", "opea_reranking_microservice.py"]
@@ -18,7 +18,7 @@ from comps.cores.proto.api_protocol import (
    RerankingResponseData,
)

-logger = CustomLogger("reranking_tei")
+logger = CustomLogger("tei_reranking")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
@@ -27,8 +27,8 @@ CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")


-@OpeaComponentRegistry.register("OPEA_RERANK_TEI")
-class OPEATEIReranking(OpeaComponent):
+@OpeaComponentRegistry.register("OPEA_TEI_RERANKING")
+class OpeaTEIReranking(OpeaComponent):
    """A specialized reranking component derived from OpeaComponent for TEI reranking services.

    Attributes:
comps/rerankings/src/integrations/videoqna.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import re

from fastapi import HTTPException

from comps import CustomLogger, LVMVideoDoc, OpeaComponentRegistry, SearchedMultimodalDoc, ServiceType
from comps.cores.common.component import OpeaComponent

logger = CustomLogger("video_reranking")
logflag = os.getenv("LOGFLAG", False)

chunk_duration = os.getenv("CHUNK_DURATION", "10") or "10"
chunk_duration = float(chunk_duration) if chunk_duration.isdigit() else 10.0

file_server_endpoint = os.getenv("FILE_SERVER_ENDPOINT") or "http://0.0.0.0:6005"

logging.basicConfig(
    level=logging.INFO, format="%(levelname)s: [%(asctime)s] %(message)s", datefmt="%d/%m/%Y %I:%M:%S"
)


def get_top_doc(top_n, videos) -> list:
    hit_score = {}
    if videos is None:
        return None
    for video_name in videos:
        try:
            if video_name not in hit_score.keys():
                hit_score[video_name] = 0
            hit_score[video_name] += 1
        except KeyError as r:
            logging.info(f"no video name {r}")

    x = dict(sorted(hit_score.items(), key=lambda item: -item[1]))  # sorted dict of video name and score
    top_n_names = list(x.keys())[:top_n]
    logging.info(f"top docs = {x}")
    logging.info(f"top n docs names = {top_n_names}")

    return top_n_names


def find_timestamp_from_video(metadata_list, video):
    return next(
        (metadata["timestamp"] for metadata in metadata_list if metadata["video"] == video),
        None,
    )


def format_video_name(video_name):
    # Check for an existing file extension
    match = re.search(r"\.(\w+)$", video_name)

    if match:
        extension = match.group(1)
        # If the extension is not 'mp4', raise an error
        if extension != "mp4":
            raise ValueError(f"Invalid file extension: .{extension}. Only '.mp4' is allowed.")

    # Use regex to remove any suffix after the base name (e.g., '_interval_0', etc.)
    base_name = re.sub(r"(_interval_\d+)?(\.mp4)?$", "", video_name)

    # Add the '.mp4' extension
    formatted_name = f"{base_name}.mp4"

    return formatted_name


@OpeaComponentRegistry.register("OPEA_VIDEO_RERANKING")
class OpeaVideoReranking(OpeaComponent):
    """A specialized reranking component derived from OpeaComponent for OPEA Video native reranking services."""

    def __init__(self, name: str, description: str, config: dict = None):
        super().__init__(name, ServiceType.RERANK.name.lower(), description, config)

    async def invoke(self, input: SearchedMultimodalDoc) -> LVMVideoDoc:
        """Invokes the reranking service to generate reranking for the provided input.

        Args:
            input (SearchedMultimodalDoc): The input in OpenAI reranking format.

        Returns:
            LVMVideoDoc: The response in OpenAI reranking format.
        """
        try:
            # get top video name from metadata
            video_names = [meta["video"] for meta in input.metadata]
            top_video_names = get_top_doc(input.top_n, video_names)

            # only use the first top video
            timestamp = find_timestamp_from_video(input.metadata, top_video_names[0])
            formatted_video_name = format_video_name(top_video_names[0])
            video_url = f"{file_server_endpoint.rstrip('/')}/{formatted_video_name}"

            result = LVMVideoDoc(
                video_url=video_url,
                prompt=input.initial_query,
                chunk_start=timestamp,
                chunk_duration=float(chunk_duration),
                max_new_tokens=512,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logging.error(f"Unexpected error in reranking: {str(e)}")
            # Handle any other exceptions with a generic server error response
            raise HTTPException(status_code=500, detail="An unexpected error occurred.")

        return result

    def check_health(self) -> bool:
        """Checks the health of the reranking service.

        Returns:
            bool: True if the service is reachable and healthy, False otherwise.
        """

        return True
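A quick illustration of the two helpers in the new file above, using made-up video names (with `get_top_doc` and `format_video_name` from that module in scope):

```python
# Illustration only: exercising the helpers with invented names.
videos = ["game_interval_0.mp4", "game_interval_1.mp4", "game_interval_0.mp4", "talk_interval_2.mp4"]
print(get_top_doc(2, videos))                     # ['game_interval_0.mp4', 'game_interval_1.mp4']
print(format_video_name("game_interval_0.mp4"))   # 'game.mp4' -- interval suffix stripped
```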
@@ -5,7 +5,8 @@ import os
import time
from typing import Union

-from integrations.opea_tei import OPEATEIReranking
+from integrations.tei import OpeaTEIReranking
+from integrations.videoqna import OpeaVideoReranking

from comps import (
    CustomLogger,
@@ -22,7 +23,7 @@ from comps.cores.proto.docarray import LLMParamsDoc, LVMVideoDoc, RerankedDoc, S
logger = CustomLogger("opea_reranking_microservice")
logflag = os.getenv("LOGFLAG", False)

-rerank_component_name = os.getenv("RERANK_COMPONENT_NAME", "OPEA_RERANK_TEI")
+rerank_component_name = os.getenv("RERANK_COMPONENT_NAME", "OPEA_TEI_RERANKING")
# Initialize OpeaComponentLoader
loader = OpeaComponentLoader(rerank_component_name, description=f"OPEA RERANK Component: {rerank_component_name}")
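The import and default-name changes above rely on a name-based component registry: classes register under a string key, and the loader instantiates whichever key `RERANK_COMPONENT_NAME` selects. The sketch below shows the general pattern only; the real `OpeaComponentRegistry`/`OpeaComponentLoader` in comps may differ in detail:

```python
# Generic sketch of the registry pattern assumed by RERANK_COMPONENT_NAME.
import os

_REGISTRY = {}

def register(name):
    def _wrap(cls):
        _REGISTRY[name] = cls
        return cls
    return _wrap

@register("OPEA_TEI_RERANKING")
class TeiRerankingStub:
    def invoke(self, doc):
        return f"TEI reranked: {doc}"

@register("OPEA_VIDEO_RERANKING")
class VideoRerankingStub:
    def invoke(self, doc):
        return f"video reranked: {doc}"

name = os.getenv("RERANK_COMPONENT_NAME", "OPEA_TEI_RERANKING")
component = _REGISTRY[name]()          # loader resolves the class by its registered name
print(component.invoke("example query"))
```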
comps/rerankings/src/requirements_videoqna.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
datasets
haystack-ai
langchain --extra-index-url https://download.pytorch.org/whl/cpu
langchain_community --extra-index-url https://download.pytorch.org/whl/cpu
openai
Pillow
pydub
@@ -1,30 +0,0 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM python:3.11-slim

ENV LANG=C.UTF-8

ARG ARCH="cpu"

RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
    libgl1-mesa-glx \
    libjemalloc-dev

RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

USER user

COPY comps /home/user/comps

RUN pip install --no-cache-dir --upgrade pip setuptools && \
    if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu; fi && \
    pip install --no-cache-dir -r /home/user/comps/reranks/src/requirements.txt

ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/reranks/src

ENTRYPOINT ["python", "opea_reranking_microservice.py"]
@@ -30,4 +30,4 @@ ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/text2sql/src/

-ENTRYPOINT ["python", "opea_text2sql_microservice.py"]
+ENTRYPOINT ["python", "opea_text2sql_microservice.py"]