Refactor DocSum example (#1286)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
118
DocSum/docsum.py
118
DocSum/docsum.py
@@ -2,7 +2,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
|
||||
@@ -20,8 +23,8 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
|
||||
|
||||
DATA_SERVICE_HOST_IP = os.getenv("DATA_SERVICE_HOST_IP", "0.0.0.0")
|
||||
DATA_SERVICE_PORT = int(os.getenv("DATA_SERVICE_PORT", 7079))
|
||||
ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
|
||||
ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 7066))
|
||||
|
||||
LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
|
||||
LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
|
||||
@@ -29,11 +32,20 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
|
||||
|
||||
def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
|
||||
if self.services[cur_node].service_type == ServiceType.LLM:
|
||||
for key_to_replace in ["text", "asr_result"]:
|
||||
if key_to_replace in inputs:
|
||||
inputs["query"] = inputs[key_to_replace]
|
||||
del inputs[key_to_replace]
|
||||
|
||||
docsum_parameters = kwargs.get("docsum_parameters", None)
|
||||
if docsum_parameters:
|
||||
docsum_parameters = docsum_parameters.model_dump()
|
||||
del docsum_parameters["query"]
|
||||
inputs.update(docsum_parameters)
|
||||
elif self.services[cur_node].service_type == ServiceType.ASR:
|
||||
if "video" in inputs:
|
||||
audio_base64 = video2audio(inputs["video"])
|
||||
inputs["audio"] = audio_base64
|
||||
return inputs
|
||||
|
||||
|
||||
@@ -45,6 +57,44 @@ def read_pdf(file):
|
||||
return docs
|
||||
|
||||
|
||||
def video2audio(
|
||||
video_base64: str,
|
||||
) -> str:
|
||||
"""Convert a base64 video string to a base64 audio string using ffmpeg.
|
||||
|
||||
Args:
|
||||
video_base64 (str): Base64 encoded video string.
|
||||
|
||||
Returns:
|
||||
str: Base64 encoded audio string.
|
||||
"""
|
||||
video_data = base64.b64decode(video_base64)
|
||||
|
||||
uid = str(uuid.uuid4())
|
||||
temp_video_path = f"{uid}.mp4"
|
||||
temp_audio_path = f"{uid}.mp3"
|
||||
with open(temp_video_path, "wb") as video_file:
|
||||
video_file.write(video_data)
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["ffmpeg", "-i", temp_video_path, "-q:a", "0", "-map", "a", temp_audio_path],
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
# Read the extracted audio file and encode it to base64
|
||||
with open(temp_audio_path, "rb") as audio_file:
|
||||
audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
|
||||
|
||||
finally:
|
||||
# Clean up the temporary video file
|
||||
os.remove(temp_video_path)
|
||||
os.remove(temp_audio_path)
|
||||
|
||||
return audio_base64
|
||||
|
||||
|
||||
def read_text_from_file(file, save_file_name):
|
||||
import docx2txt
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
@@ -78,17 +128,18 @@ class DocSumService:
|
||||
self.port = port
|
||||
ServiceOrchestrator.align_inputs = align_inputs
|
||||
self.megaservice = ServiceOrchestrator()
|
||||
self.megaservice_text_only = ServiceOrchestrator()
|
||||
self.endpoint = str(MegaServiceEndpoint.DOC_SUMMARY)
|
||||
|
||||
def add_remote_service(self):
|
||||
|
||||
data = MicroService(
|
||||
name="multimedia2text",
|
||||
host=DATA_SERVICE_HOST_IP,
|
||||
port=DATA_SERVICE_PORT,
|
||||
endpoint="/v1/multimedia2text",
|
||||
asr = MicroService(
|
||||
name="asr",
|
||||
host=ASR_SERVICE_HOST_IP,
|
||||
port=ASR_SERVICE_PORT,
|
||||
endpoint="/v1/asr",
|
||||
use_remote_service=True,
|
||||
service_type=ServiceType.DATAPREP,
|
||||
service_type=ServiceType.ASR,
|
||||
)
|
||||
|
||||
llm = MicroService(
|
||||
@@ -100,10 +151,12 @@ class DocSumService:
|
||||
service_type=ServiceType.LLM,
|
||||
)
|
||||
|
||||
self.megaservice.add(data).add(llm)
|
||||
self.megaservice.flow_to(data, llm)
|
||||
self.megaservice.add(asr).add(llm)
|
||||
self.megaservice.flow_to(asr, llm)
|
||||
self.megaservice_text_only.add(llm)
|
||||
|
||||
async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
|
||||
"""Accept pure text, or files .txt/.pdf.docx, audio/video base64 string."""
|
||||
|
||||
if "application/json" in request.headers.get("content-type"):
|
||||
data = await request.json()
|
||||
@@ -129,11 +182,15 @@ class DocSumService:
|
||||
file_summaries = []
|
||||
if files:
|
||||
for file in files:
|
||||
file_path = f"/tmp/{file.filename}"
|
||||
# Fix concurrency issue with the same file name
|
||||
# https://github.com/opea-project/GenAIExamples/issues/1279
|
||||
uid = str(uuid.uuid4())
|
||||
file_path = f"/tmp/{uid}"
|
||||
|
||||
if data_type is not None and data_type in ["audio", "video"]:
|
||||
raise ValueError(
|
||||
"Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
|
||||
"Audio and Video file uploads are not supported in docsum with curl request, \
|
||||
please use the UI or pass base64 string of the content directly."
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -181,19 +238,34 @@ class DocSumService:
|
||||
chunk_overlap=chunk_overlap,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
text_only = "text" in initial_inputs_data
|
||||
if not text_only:
|
||||
result_dict, runtime_graph = await self.megaservice.schedule(
|
||||
initial_inputs=initial_inputs_data, docsum_parameters=docsum_parameters
|
||||
)
|
||||
|
||||
result_dict, runtime_graph = await self.megaservice.schedule(
|
||||
initial_inputs=initial_inputs_data, docsum_parameters=docsum_parameters
|
||||
)
|
||||
for node, response in result_dict.items():
|
||||
# Here it suppose the last microservice in the megaservice is LLM.
|
||||
if (
|
||||
isinstance(response, StreamingResponse)
|
||||
and node == list(self.megaservice.services.keys())[-1]
|
||||
and self.megaservice.services[node].service_type == ServiceType.LLM
|
||||
):
|
||||
return response
|
||||
else:
|
||||
result_dict, runtime_graph = await self.megaservice_text_only.schedule(
|
||||
initial_inputs=initial_inputs_data, docsum_parameters=docsum_parameters
|
||||
)
|
||||
|
||||
for node, response in result_dict.items():
|
||||
# Here it suppose the last microservice in the megaservice is LLM.
|
||||
if (
|
||||
isinstance(response, StreamingResponse)
|
||||
and node == list(self.megaservice.services.keys())[-1]
|
||||
and self.megaservice.services[node].service_type == ServiceType.LLM
|
||||
):
|
||||
return response
|
||||
|
||||
for node, response in result_dict.items():
|
||||
# Here it suppose the last microservice in the megaservice is LLM.
|
||||
if (
|
||||
isinstance(response, StreamingResponse)
|
||||
and node == list(self.megaservice.services.keys())[-1]
|
||||
and self.megaservice.services[node].service_type == ServiceType.LLM
|
||||
):
|
||||
return response
|
||||
last_node = runtime_graph.all_leaves()[-1]
|
||||
response = result_dict[last_node]["text"]
|
||||
choices = []
|
||||
|
||||
Reference in New Issue
Block a user