* remove examples gateway. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove gateway. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine service code. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update http_service.py * remove gateway ut. * remove gateway ut. * fix conflict service name. * Update http_service.py * add handle message ut. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove `multiprocessing.Process` start server code. * fix ut. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove multiprocessing and enhance ut for coverage. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: chen, suyue <suyue.chen@intel.com>
334 lines
11 KiB
Python
334 lines
11 KiB
Python
# Copyright (C) 2024 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import base64
|
|
import ipaddress
|
|
import json
|
|
import multiprocessing
|
|
import os
|
|
import random
|
|
from io import BytesIO
|
|
from socket import AF_INET, SOCK_STREAM, socket
|
|
from typing import List, Optional, Union
|
|
|
|
import requests
|
|
from PIL import Image
|
|
|
|
from .logger import CustomLogger
|
|
|
|
|
|
def is_port_free(host: str, port: int) -> bool:
    """Report whether nothing accepts TCP connections on ``host:port``.

    A TCP connection attempt is made; any non-zero result from ``connect_ex``
    is interpreted as "the port is free".

    :param host: Host name or address to probe.
    :param port: TCP port number to probe.
    :return: True if the connection attempt failed, False if it succeeded.
    """
    probe = socket(AF_INET, SOCK_STREAM)
    with probe:
        status = probe.connect_ex((host, port))
    return status != 0
|
|
|
|
|
|
def check_ports_availability(host: Union[str, List[str]], port: Union[int, List[int]]) -> bool:
    """Report whether every given port is free on every given host.

    Each (host, port) pair is checked with ``is_port_free``; checking stops at
    the first pair that is not free.

    :param host: A single host or a list of hosts.
    :param port: A single port or a list of ports.
    :return: True if all ports on all hosts are free, False otherwise.
    """
    host_list = host if not isinstance(host, str) else [host]
    port_list = port if not isinstance(port, int) else [port]
    for candidate_host in host_list:
        for candidate_port in port_list:
            if not is_port_free(candidate_host, candidate_port):
                return False
    return True
|
|
|
|
|
|
def get_internal_ip():
    """Best-effort lookup of this machine's private IPv4 address.

    A UDP socket is connected toward a private address and the local address
    the OS bound it to is read back. Any failure falls back to loopback.

    :return: Private IP address, or ``"127.0.0.1"`` if it cannot be determined.
    """
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp:
            udp.connect(("10.255.255.255", 1))
            return udp.getsockname()[0]
    except Exception:
        return "127.0.0.1"
|
|
|
|
|
|
def get_public_ip(timeout: float = 0.3) -> Optional[str]:
    """Return the public IP address of this host as seen by external lookup services.

    Several plain-text IP lookup services are queried in order; the first
    response body that parses as an IP address is returned.

    :param timeout: Per-request timeout in seconds, passed to ``urlopen``.
    :return: The public IP address as a string, or None if every service failed.
    """
    import urllib.request

    def _get_public_ip(url):
        try:
            req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
            with urllib.request.urlopen(req, timeout=timeout) as fp:
                candidate = fp.read().decode().strip()
            # Reject bodies that are not an IP address (e.g. an HTML error page).
            ipaddress.ip_address(candidate)
            return candidate
        except Exception:
            # Best effort: any network, decoding or parse failure means "try the next service".
            # (Unlike a bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.)
            return None

    ip_lookup_services = [
        "https://api.ipify.org",
        "https://ident.me",
        "https://checkip.amazonaws.com/",
    ]

    for url in ip_lookup_services:
        ip = _get_public_ip(url)
        if ip:
            return ip
    return None
|
|
|
|
|
|
def typename(obj):
    """Return the fully qualified ``module.Name`` of *obj*'s type.

    If *obj* is itself a class, that class is described; otherwise its
    ``__class__`` is. If the module or name attribute is missing, ``str()`` of
    the class is returned instead.
    """
    cls = obj if isinstance(obj, type) else obj.__class__
    try:
        module, name = cls.__module__, cls.__name__
    except AttributeError:
        return str(cls)
    return f"{module}.{name}"
|
|
|
|
|
|
def get_event(obj) -> "multiprocessing.synchronize.Event":
    """Create a ``multiprocessing`` Event matching the start method of process *obj*.

    :param obj: A ``multiprocessing`` process object (it does not need to be started).
    :return: A new, unset Event from the matching multiprocessing context.
    :raises TypeError: If *obj* is not a supported multiprocessing process type.
    """
    # A plain ``multiprocessing.Process`` follows the default start method.
    if isinstance(obj, multiprocessing.Process):
        return multiprocessing.Event()
    # The context-specific process classes are not defined on every platform
    # (e.g. ForkProcess/ForkServerProcess are absent on Windows), so look them
    # up defensively instead of referencing them as attributes directly.
    for class_name, start_method in (
        ("ForkProcess", "fork"),
        ("SpawnProcess", "spawn"),
        ("ForkServerProcess", "forkserver"),
    ):
        process_cls = getattr(multiprocessing.context, class_name, None)
        if process_cls is not None and isinstance(obj, process_cls):
            return multiprocessing.get_context(start_method).Event()
    raise TypeError(f'{obj} is not an instance of "multiprocessing.Process"')
|
|
|
|
|
|
def in_docker():
    """Best-effort check whether the current process runs inside Docker.

    Returns True if ``/.dockerenv`` exists, or if any line of
    ``/proc/self/cgroup`` contains ``docker``; otherwise False.
    """
    if os.path.exists("/.dockerenv"):
        return True
    cgroup_file = "/proc/self/cgroup"
    if not os.path.isfile(cgroup_file):
        return False
    with open(cgroup_file, encoding="utf-8") as handle:
        for line in handle:
            if "docker" in line:
                return True
    return False
|
|
|
|
|
|
def host_is_local(hostname):
    """Check whether *hostname* refers to the local machine.

    Returns True for:

    - ``"localhost"``;
    - the unspecified (bind-all) addresses ``0.0.0.0`` and ``::``;
    - any loopback IP literal (e.g. ``127.0.0.1``, ``::1``);
    - any name whose fully-qualified form (``socket.getfqdn``) is
      ``localhost`` or ``0.0.0.0``.
    """
    import socket

    # Well-known local names need no resolver round-trip (getfqdn("localhost")
    # may return e.g. "localhost.localdomain" depending on system config).
    if hostname in ("localhost", "0.0.0.0"):
        return True

    try:
        ip = ipaddress.ip_address(hostname)
    except ValueError:
        ip = None
    # IP literals can be classified directly, without a reverse lookup.
    if ip is not None and (ip.is_loopback or ip.is_unspecified):
        return True

    # Fall back to the resolver for other names/addresses.
    return socket.getfqdn(hostname) in ("localhost", "0.0.0.0")
|
|
|
|
|
|
# Ports already handed out by random_port().
assigned_ports = set()
# Shuffled pool of candidate ports that random_port() draws from; refilled by reset_ports().
unassigned_ports = []
# Default bounds of the random-port range (overridable via the
# JINA_RANDOM_PORT_MIN / JINA_RANDOM_PORT_MAX environment variables).
DEFAULT_MIN_PORT = 49153
MAX_PORT = 65535


def reset_ports():
    """Forget all assigned ports and refill ``unassigned_ports`` with a shuffled port range.

    The range is ``[min_port, max_port]``:

    - ``max_port`` is ``JINA_RANDOM_PORT_MAX`` (default ``MAX_PORT``).
    - ``min_port`` is ``JINA_RANDOM_PORT_MIN``. Its default is ``DEFAULT_MIN_PORT``,
      lowered to 16384 when fewer than 100 ports of the default range were still
      unassigned at the time of the call.
    """

    def _get_unassigned_ports(running_out: bool):
        # if we are running out of ports, lower default minimum port
        if running_out:
            min_port = int(os.environ.get("JINA_RANDOM_PORT_MIN", "16384"))
        else:
            min_port = int(os.environ.get("JINA_RANDOM_PORT_MIN", str(DEFAULT_MIN_PORT)))
        max_port = int(os.environ.get("JINA_RANDOM_PORT_MAX", str(MAX_PORT)))
        return set(range(min_port, max_port + 1)) - set(assigned_ports)

    # Decide *before* clearing: once assigned_ports is emptied its length is
    # always 0, which previously made the "running out" branch unreachable.
    running_out = MAX_PORT - DEFAULT_MIN_PORT - len(assigned_ports) < 100
    unassigned_ports.clear()
    assigned_ports.clear()
    unassigned_ports.extend(_get_unassigned_ports(running_out))
    random.shuffle(unassigned_ports)
|
|
|
|
|
|
def random_port() -> Optional[int]:
    """Pick a random port that can currently be bound and record it as assigned.

    Candidates come from the module-level ``unassigned_ports`` pool, which is
    refilled with ``reset_ports()`` when empty. The first candidate that binds
    successfully is removed from the pool and added to ``assigned_ports``.
    If no candidate binds, both bookkeeping containers are cleared and the
    search is attempted once more.

    :return: A random port.
    :raises OSError: If the second attempt also finds no bindable port.
    """

    def _pick():
        import socket

        def _bindable(candidate):
            with socket.socket() as probe:
                try:
                    probe.bind(("", candidate))
                    probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                except OSError:
                    return False
                return True

        if not unassigned_ports:
            reset_ports()
        chosen_idx = next(
            (i for i, candidate in enumerate(unassigned_ports) if _bindable(candidate)),
            None,
        )
        if chosen_idx is None:
            raise OSError(
                f"can not find an available port in {len(unassigned_ports)} unassigned ports, assigned already {len(assigned_ports)} ports"
            )
        chosen = int(unassigned_ports.pop(chosen_idx))
        assigned_ports.add(chosen)
        return chosen

    try:
        return _pick()
    except OSError:
        # Nothing in the pool could be bound: drop all bookkeeping and retry once.
        assigned_ports.clear()
        unassigned_ports.clear()
        return _pick()
|
|
|
|
|
|
class ConfigError(Exception):
    """Raised when model configuration data is malformed, missing fields, or has empty values."""
|
|
|
|
|
|
def load_model_configs(model_configs: str) -> dict:
    """Parse and validate a JSON array of model configurations.

    Each array entry must be a JSON object with non-empty values for
    ``model_name``, ``displayName``, ``endpoint``, ``minToken`` and ``maxToken``.

    :param model_configs: JSON text (e.g. the MODEL_CONFIGS environment variable).
    :return: Mapping of each entry's ``model_name`` to its configuration dict;
        when several entries share a ``model_name`` the last one wins.
    :raises ConfigError: The text is invalid JSON or not a non-empty array, or
        an entry is not an object or has missing/empty required fields.
    """
    logger = CustomLogger("models_loader")
    # A tuple (not a set) keeps the field order in error messages stable across runs.
    required_keys = ("model_name", "displayName", "endpoint", "minToken", "maxToken")
    try:
        configs = json.loads(model_configs)
        if not isinstance(configs, list) or not configs:
            raise ConfigError("MODEL_CONFIGS must be a non-empty JSON array.")
        configs_map = {}
        for config in configs:
            # Without this check, `key in config` would do substring/element tests on
            # strings/lists and raise an uncaught TypeError on numbers.
            if not isinstance(config, dict):
                raise ConfigError(f"Each model configuration must be a JSON object, got {type(config).__name__}.")
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ConfigError(f"Missing required configuration fields: {missing_keys}")
            # NOTE(review): any falsy value (e.g. minToken=0) is also treated as empty — confirm intended.
            empty_keys = [key for key in required_keys if not config.get(key)]
            if empty_keys:
                raise ConfigError(f"Empty values found for configuration fields: {empty_keys}")
            configs_map[config["model_name"]] = config
        if not configs_map:
            # Defensive: a non-empty array either raises above or fills the map.
            raise ConfigError("No valid configurations found.")
        return configs_map
    except json.JSONDecodeError as err:
        logger.error("Error parsing MODEL_CONFIGS environment variable as JSON.")
        raise ConfigError("MODEL_CONFIGS is not valid JSON.") from err
    except ConfigError as e:
        logger.error(str(e))
        raise
|
|
|
|
|
|
def get_access_token(
    token_url: str, client_id: str, client_secret: str, timeout: Optional[float] = 30.0
) -> str:
    """Fetch an access token using the OAuth client-credentials flow.

    :param token_url: URL of the token endpoint.
    :param client_id: OAuth client id.
    :param client_secret: OAuth client secret.
    :param timeout: Seconds to wait for the token endpoint, passed to ``requests.post``.
        Without a timeout the request could block indefinitely; pass ``None`` to
        restore that unbounded behavior.
    :return: The ``access_token`` field of the JSON response, or ``""`` when the
        endpoint does not answer with HTTP 200 or the field is absent.
    """
    logger = CustomLogger("tgi_or_tei_service_auth")
    data = {
        "client_id": client_id,
        "client_secret": client_secret,
        "grant_type": "client_credentials",
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    # Network errors (including timeouts) propagate to the caller as requests exceptions.
    response = requests.post(token_url, data=data, headers=headers, timeout=timeout)
    if response.status_code == 200:
        token_info = response.json()
        return token_info.get("access_token", "")
    logger.error(f"Failed to retrieve access token: {response.status_code}, {response.text}")
    return ""
|
|
|
|
|
|
class SafeContextManager:
    """Forward an exceptional exit to another context manager.

    Presumably meant to guard setup code (such as the rest of an ``__init__``)
    that runs after ``context_to_manage`` has been entered:
    ``with SafeContextManager(ctx): ...``. If that body raises,
    ``ctx.__exit__`` is called with the exception details so it can clean up.

    - On a normal exit nothing is forwarded.
    - The wrapped object's ``__enter__`` is never called.
    - ``as`` binds ``None``.
    - The exception is never suppressed, whatever the wrapped ``__exit__`` returns.
    """

    def __init__(self, context_to_manage):
        self.context_to_manage = context_to_manage

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not exc_type:
            return None
        self.context_to_manage.__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
|
def _image_to_b64_png(image_source):
    """Return *image_source* as a base64 PNG string.

    HTTP(S) URLs are downloaded and existing local paths are read; both are
    converted to RGBA and re-encoded as PNG. Any other string is returned
    unchanged and is assumed to already be base64-encoded image data.
    """
    if image_source.startswith(("http://", "https://")):
        response = requests.get(image_source)
        image = Image.open(BytesIO(response.content)).convert("RGBA")
    elif os.path.exists(image_source):
        # Close the file handle once the pixels have been converted/loaded.
        with Image.open(image_source) as opened:
            image = opened.convert("RGBA")
    else:
        return image_source
    image_bytes = BytesIO()
    image.save(image_bytes, format="PNG")
    return base64.b64encode(image_bytes.getvalue()).decode()


def handle_message(messages):
    """Flatten chat messages into a single text prompt, plus images if any.

    :param messages: Either a plain prompt string (returned unchanged) or a list of
        ``{"role": ..., "content": ...}`` dicts whose role is ``system``, ``user`` or
        ``assistant``. A user ``content`` may also be a list of parts with
        ``"type": "text"`` (``{"text": ...}``) or ``"type": "image_url"``
        (``{"image_url": {"url": ...}}``).
    :return: The prompt string, or a ``(prompt, images)`` tuple when any image was
        supplied; ``images`` holds base64 strings in message order.
    :raises ValueError: If a message has an unknown role.

    The prompt layout is:

    - the last system message (if any) followed by a newline;
    - then one ``<role>: <text>`` line, ending in a newline, per user/assistant
      message, in conversation order;
    - a message with empty text contributes only ``<role>:`` with no newline.
    """
    if isinstance(messages, str):
        return messages

    system_prompt = ""
    # Keep every turn in order; a dict keyed by role would keep only the last
    # user/assistant message and lose the conversation history.
    turns = []
    for message in messages:
        msg_role = message["role"]
        if msg_role == "system":
            system_prompt = message["content"]
        elif msg_role == "user":
            content = message["content"]
            if isinstance(content, list):
                text = "\n".join(item["text"] for item in content if item["type"] == "text")
                image_list = [item["image_url"]["url"] for item in content if item["type"] == "image_url"]
                turns.append((msg_role, (text, image_list) if image_list else text))
            else:
                turns.append((msg_role, content))
        elif msg_role == "assistant":
            turns.append((msg_role, message["content"]))
        else:
            raise ValueError(f"Unknown role: {msg_role}")

    prompt = system_prompt + "\n" if system_prompt else ""
    images = []
    for role, content in turns:
        if isinstance(content, tuple):
            text, image_list = content
            images.extend(_image_to_b64_png(img) for img in image_list)
        else:
            text = content
        if text:
            prompt += role + ": " + text + "\n"
        else:
            prompt += role + ":"

    if images:
        return prompt, images
    return prompt
|