Support multiple image sources for LVM microservice (#451)
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
@@ -2,10 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+import os
+from io import BytesIO
 
 import requests
 from fastapi import Request
 from fastapi.responses import StreamingResponse
+from PIL import Image
 
 from ..proto.api_protocol import (
     AudioChatCompletionRequest,
@@ -74,6 +77,7 @@ class Gateway:
         pass
 
     def _handle_message(self, messages):
+        images = []
         if isinstance(messages, str):
             prompt = messages
         else:
@@ -104,7 +108,6 @@ class Gateway:
                     raise ValueError(f"Unknown role: {msg_role}")
             if system_prompt:
                 prompt = system_prompt + "\n"
-            images = []
             for role, message in messages_dict.items():
                 if isinstance(message, tuple):
                     text, image_list = message
@@ -113,8 +116,24 @@ class Gateway:
                     else:
                         prompt += role + ":"
                     for img in image_list:
-                        response = requests.get(img)
-                        images.append(base64.b64encode(response.content).decode("utf-8"))
+                        # URL
+                        if img.startswith("http://") or img.startswith("https://"):
+                            response = requests.get(img)
+                            image = Image.open(BytesIO(response.content)).convert("RGBA")
+                            image_bytes = BytesIO()
+                            image.save(image_bytes, format="PNG")
+                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                        # Local Path
+                        elif os.path.exists(img):
+                            image = Image.open(img).convert("RGBA")
+                            image_bytes = BytesIO()
+                            image.save(image_bytes, format="PNG")
+                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                        # Bytes
+                        else:
+                            img_b64_str = img
+
+                        images.append(img_b64_str)
                 else:
                     if message:
                         prompt += role + ": " + message + "\n"
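
Note: the new branching above normalizes every image reference, whatever its source, to a base64-encoded PNG string before appending it to images. The same logic as a minimal standalone sketch (the helper name normalize_image_to_b64 is illustrative, not part of this commit):

import base64
import os
from io import BytesIO

import requests
from PIL import Image

def normalize_image_to_b64(img: str) -> str:
    """Normalize a URL, local path, or pre-encoded string to base64 PNG."""
    if img.startswith("http://") or img.startswith("https://"):
        # Remote URL: fetch the bytes, then re-encode below.
        response = requests.get(img)
        image = Image.open(BytesIO(response.content)).convert("RGBA")
    elif os.path.exists(img):
        # Local path: open the file directly.
        image = Image.open(img).convert("RGBA")
    else:
        # Anything else is assumed to already be base64-encoded image data.
        return img
    # Re-encode as PNG so the downstream service sees a uniform format.
    image_bytes = BytesIO()
    image.save(image_bytes, format="PNG")
    return base64.b64encode(image_bytes.getvalue()).decode()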
@@ -48,7 +48,7 @@ async def lvm(request: LVMDoc):
     async def stream_generator():
         chat_response = ""
         text_generation = await lvm_client.text_generation(
-            prompt=prompt,
+            prompt=image_prompt,
             stream=streaming,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
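
Note: this hunk only shows the call site; the diff does not show how image_prompt is built. A plausible construction, assuming TGI's llava-next convention of embedding the image in the text prompt as a markdown data-URL reference (the exact format string is an assumption, not taken from this commit):

# Assumption: the serving endpoint reads images embedded in the prompt as
# markdown data-URL references; img_b64_str comes from the incoming request.
image_prompt = f"![](data:image/png;base64,{img_b64_str})\n{prompt}"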
@@ -5,6 +5,7 @@ httpx
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk
+Pillow
 prometheus-fastapi-instrumentator
 pyyaml
 requests
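
Note: Pillow is the new runtime dependency pulled in by the PNG re-encoding above. Taken together, the gateway now accepts any mix of the three image source types; a quick exercise of the sketch from the gateway hunk (URL and path are placeholders):

# Placeholders: any reachable image URL / existing image file would work.
for src in (
    "https://example.com/sample.png",         # remote URL
    "./sample.png",                           # local path
    base64.b64encode(b"raw-bytes").decode(),  # passed through as-is
):
    print(normalize_image_to_b64(src)[:32])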