Docsum Gateway Fix (#902)

* update gateway

Signed-off-by: Mustafa <mustafa.cetin@intel.com>

* update the gateway

Signed-off-by: Mustafa <mustafa.cetin@intel.com>

* update the gateway

Signed-off-by: Mustafa <mustafa.cetin@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mustafa <mustafa.cetin@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Mustafa
2024-11-14 19:14:50 -08:00
committed by GitHub
parent 405a632b31
commit d211cb2dbd

View File

@@ -419,12 +419,62 @@ class DocSumGateway(Gateway):
output_datatype=ChatCompletionResponse,
)
async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.model_validate(data)
async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
if "application/json" in request.headers.get("content-type"):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.model_validate(data)
prompt = self._handle_message(chat_request.messages)
initial_inputs_data = {data["type"]: prompt}
elif "multipart/form-data" in request.headers.get("content-type"):
data = await request.form()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.model_validate(data)
data_type = data.get("type")
file_summaries = []
if files:
for file in files:
file_path = f"/tmp/{file.filename}"
if data_type is not None and data_type in ["audio", "video"]:
raise ValueError(
"Audio and Video file uploads are not supported in docsum with curl request, please use the UI."
)
else:
import aiofiles
async with aiofiles.open(file_path, "wb") as f:
await f.write(await file.read())
docs = read_text_from_file(file, file_path)
os.remove(file_path)
if isinstance(docs, list):
file_summaries.extend(docs)
else:
file_summaries.append(docs)
if file_summaries:
prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
else:
prompt = self._handle_message(chat_request.messages)
data_type = data.get("type")
if data_type is not None:
initial_inputs_data = {}
initial_inputs_data[data_type] = prompt
else:
initial_inputs_data = {"query": prompt}
else:
raise ValueError(f"Unknown request type: {request.headers.get('content-type')}")
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -434,12 +484,14 @@ class DocSumGateway(Gateway):
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
language=chat_request.language if chat_request.language else "auto",
model=chat_request.model if chat_request.model else None,
language=chat_request.language if chat_request.language else "auto",
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={data["type"]: prompt}, llm_parameters=parameters
initial_inputs=initial_inputs_data, llm_parameters=parameters
)
for node, response in result_dict.items():
# This assumes the last microservice in the megaservice is the LLM.
if (