Refactor DocSum example (#1286)

Author: Sihan Chen
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Committed by: GitHub
Date: 2024-12-26 14:45:17 +08:00
Parent: 6b6a08df78
Commit: a01729a5c2

16 changed files with 145 additions and 1143 deletions


@@ -2,7 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import os
import subprocess
import uuid
from typing import List
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
@@ -20,8 +23,8 @@ from fastapi.responses import StreamingResponse
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
DATA_SERVICE_HOST_IP = os.getenv("DATA_SERVICE_HOST_IP", "0.0.0.0")
DATA_SERVICE_PORT = int(os.getenv("DATA_SERVICE_PORT", 7079))
ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 7066))
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
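The new ASR endpoint is configured through the same environment-variable pattern as the existing services; a minimal sketch for a local run, with hypothetical values set before the gateway module is imported:

# Hypothetical local configuration (values are placeholders)
import os
os.environ.setdefault("ASR_SERVICE_HOST_IP", "127.0.0.1")
os.environ.setdefault("ASR_SERVICE_PORT", "7066")
os.environ.setdefault("LLM_SERVICE_HOST_IP", "127.0.0.1")
os.environ.setdefault("LLM_SERVICE_PORT", "9000")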
@@ -29,11 +32,20 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
if self.services[cur_node].service_type == ServiceType.LLM:
for key_to_replace in ["text", "asr_result"]:
if key_to_replace in inputs:
inputs["query"] = inputs[key_to_replace]
del inputs[key_to_replace]
docsum_parameters = kwargs.get("docsum_parameters", None)
if docsum_parameters:
docsum_parameters = docsum_parameters.model_dump()
del docsum_parameters["query"]
inputs.update(docsum_parameters)
elif self.services[cur_node].service_type == ServiceType.ASR:
if "video" in inputs:
audio_base64 = video2audio(inputs["video"])
inputs["audio"] = audio_base64
return inputs
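The key-renaming step that the LLM branch performs can be reproduced in isolation; a minimal sketch with a made-up payload:

# Standalone sketch of the renaming done for the LLM node (sample data is made up)
inputs = {"asr_result": "transcribed meeting audio ...", "id": "123"}
for key_to_replace in ["text", "asr_result"]:
    if key_to_replace in inputs:
        inputs["query"] = inputs.pop(key_to_replace)
print(inputs)  # {'id': '123', 'query': 'transcribed meeting audio ...'}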
@@ -45,6 +57,44 @@ def read_pdf(file):
return docs
def video2audio(
video_base64: str,
) -> str:
"""Convert a base64 video string to a base64 audio string using ffmpeg.
Args:
video_base64 (str): Base64 encoded video string.
Returns:
str: Base64 encoded audio string.
"""
video_data = base64.b64decode(video_base64)
uid = str(uuid.uuid4())
temp_video_path = f"{uid}.mp4"
temp_audio_path = f"{uid}.mp3"
with open(temp_video_path, "wb") as video_file:
video_file.write(video_data)
try:
subprocess.run(
["ffmpeg", "-i", temp_video_path, "-q:a", "0", "-map", "a", temp_audio_path],
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT,
)
# Read the extracted audio file and encode it to base64
with open(temp_audio_path, "rb") as audio_file:
audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
finally:
# Clean up the temporary files; if ffmpeg failed, the audio file may not exist
os.remove(temp_video_path)
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return audio_base64
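A quick way to exercise video2audio on its own (assumes ffmpeg is on PATH and a local sample.mp4 exists; file names are placeholders):

# Example usage of video2audio; requires ffmpeg and ./sample.mp4
import base64

with open("sample.mp4", "rb") as f:
    video_b64 = base64.b64encode(f.read()).decode("utf-8")

audio_b64 = video2audio(video_b64)

with open("sample.mp3", "wb") as f:
    f.write(base64.b64decode(audio_b64))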
def read_text_from_file(file, save_file_name):
import docx2txt
from langchain.text_splitter import CharacterTextSplitter
@@ -78,17 +128,18 @@ class DocSumService:
self.port = port
ServiceOrchestrator.align_inputs = align_inputs
self.megaservice = ServiceOrchestrator()
self.megaservice_text_only = ServiceOrchestrator()
self.endpoint = str(MegaServiceEndpoint.DOC_SUMMARY)
def add_remote_service(self):
data = MicroService(
name="multimedia2text",
host=DATA_SERVICE_HOST_IP,
port=DATA_SERVICE_PORT,
endpoint="/v1/multimedia2text",
asr = MicroService(
name="asr",
host=ASR_SERVICE_HOST_IP,
port=ASR_SERVICE_PORT,
endpoint="/v1/asr",
use_remote_service=True,
service_type=ServiceType.DATAPREP,
service_type=ServiceType.ASR,
)
llm = MicroService(
@@ -100,10 +151,12 @@ class DocSumService:
service_type=ServiceType.LLM,
)
self.megaservice.add(data).add(llm)
self.megaservice.flow_to(data, llm)
self.megaservice.add(asr).add(llm)
self.megaservice.flow_to(asr, llm)
self.megaservice_text_only.add(llm)
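With this change the gateway keeps two orchestrators: self.megaservice routes audio/video inputs through the asr node into llm, while self.megaservice_text_only holds only the llm node, so plain-text summarization no longer passes through the ASR hop.

# Resulting graphs (illustrative):
#   megaservice:            asr -> llm   (audio / video inputs)
#   megaservice_text_only:  llm          (plain-text inputs)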
async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
"""Accept pure text, or files .txt/.pdf.docx, audio/video base64 string."""
if "application/json" in request.headers.get("content-type"):
data = await request.json()
@@ -129,11 +182,15 @@ class DocSumService:
file_summaries = []
if files:
for file in files:
file_path = f"/tmp/{file.filename}"
# Fix concurrency issue with the same file name
# https://github.com/opea-project/GenAIExamples/issues/1279
uid = str(uuid.uuid4())
file_path = f"/tmp/{uid}"
if data_type is not None and data_type in ["audio", "video"]:
raise ValueError(
"Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
"Audio and Video file uploads are not supported in docsum with curl request, \
please use the UI or pass base64 string of the content directly."
)
else:
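As the new error message suggests, audio and video content can still be summarized over plain HTTP by sending the base64 string in a JSON body instead of a file upload; a hedged sketch using requests (the "type"/"messages" field names and the endpoint follow the DocSum README and may differ in your deployment):

# Posting base64-encoded audio as JSON instead of a multipart file upload
# (field names and endpoint are assumptions based on the DocSum README)
import base64
import requests

with open("sample.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:8888/v1/docsum",
    json={"type": "audio", "messages": audio_b64},
    timeout=300,
)
print(resp.text)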
@@ -181,19 +238,34 @@ class DocSumService:
chunk_overlap=chunk_overlap,
chunk_size=chunk_size,
)
text_only = "text" in initial_inputs_data
if not text_only:
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs_data, docsum_parameters=docsum_parameters
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs_data, docsum_parameters=docsum_parameters
)
for node, response in result_dict.items():
# Here it is assumed that the last microservice in the megaservice is the LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
else:
result_dict, runtime_graph = await self.megaservice_text_only.schedule(
initial_inputs=initial_inputs_data, docsum_parameters=docsum_parameters
)
for node, response in result_dict.items():
# Here it is assumed that the last microservice in the megaservice is the LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
for node, response in result_dict.items():
# Here it is assumed that the last microservice in the megaservice is the LLM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []