Files
GenAIExamples/comps/cores/mega/utils.py
lkk f5efaf1f18 remove examples gateway. (#979)
* remove examples gateway.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove gateway.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine service code.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update http_service.py

* remove gateway ut.

* remove gateway ut.

* fix conflict service name.

* Update http_service.py

* add handle message ut.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove `multiprocessing.Process` start server code.

* fix ut.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove multiprocessing and enhance ut for coverage.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>
2024-12-13 09:31:11 +08:00

334 lines
11 KiB
Python

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import base64
import ipaddress
import json
import multiprocessing
import os
import random
from io import BytesIO
from socket import AF_INET, SOCK_STREAM, socket
from typing import List, Optional, Union
import requests
from PIL import Image
from .logger import CustomLogger
def is_port_free(host: str, port: int) -> bool:
    """Report whether nothing is accepting TCP connections on ``host:port``.

    A TCP connection attempt is made with ``connect_ex``; any non-zero
    result code (refused, unreachable, ...) is treated as "free".

    :param host: The host to check.
    :param port: The port to check.
    :return: True if the port is free, False otherwise.
    """
    probe = socket(AF_INET, SOCK_STREAM)
    try:
        result_code = probe.connect_ex((host, port))
    finally:
        probe.close()
    return result_code != 0
def check_ports_availability(host: Union[str, List[str]], port: Union[int, List[int]]) -> bool:
    """Return True only if every (host, port) combination is free.

    A single host string or a single port int is treated as a one-element
    list. Checking stops at the first busy combination.

    :param host: The host(s) to check.
    :param port: The port(s) to check.
    :return: True if all ports on all hosts are free, False otherwise.
    """
    host_list = host if not isinstance(host, str) else [host]
    port_list = port if not isinstance(port, int) else [port]
    for candidate_host in host_list:
        for candidate_port in port_list:
            if not is_port_free(candidate_host, candidate_port):
                return False
    return True
def get_internal_ip():
    """Return this machine's private (LAN-facing) IPv4 address.

    A UDP socket is "connected" to a non-routable private address. For UDP
    this only selects the outgoing interface; the local address of that
    interface is then read back. Any failure falls back to ``"127.0.0.1"``.

    :return: Private IP address.
    """
    import socket

    def _probe():
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp:
            udp.connect(("10.255.255.255", 1))
            found = udp.getsockname()[0]
        return found

    try:
        return _probe()
    except Exception:
        # Best effort: no usable interface (or no network) -> loopback.
        return "127.0.0.1"
def get_public_ip(timeout: float = 0.3, services: Optional[List[str]] = None) -> Optional[str]:
    """Return the public IP address of this host as seen from the internet.

    Each lookup service is queried in order, and the first non-empty,
    whitespace-stripped response body is returned. A service that fails
    (network error, timeout, HTTP error, undecodable body) is skipped.

    :param timeout: Per-request timeout in seconds.
    :param services: Optional list of lookup URLs to try instead of the
        built-in defaults (api.ipify.org, ident.me, checkip.amazonaws.com).
    :return: The public IP as a string, or None if every service failed.
    """
    import urllib.request

    default_services = [
        "https://api.ipify.org",
        "https://ident.me",
        "https://checkip.amazonaws.com/",
    ]

    def _get_public_ip(url):
        try:
            req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
            with urllib.request.urlopen(req, timeout=timeout) as fp:
                return fp.read().decode().strip()
        except Exception:
            # Best effort: any failure just means "try the next service".
            # ``Exception`` (not a bare ``except:``) so Ctrl-C / SystemExit still propagate.
            return None

    for url in default_services if services is None else services:
        ip = _get_public_ip(url)
        if ip:
            return ip
    return None
def typename(obj):
    """Return the fully qualified ``module.Name`` of *obj*'s type.

    A class is described directly; any other value is described by its
    class. If the module or name attribute is missing, ``str()`` of the
    class is returned instead.
    """
    cls = obj if isinstance(obj, type) else obj.__class__
    try:
        module, name = cls.__module__, cls.__name__
    except AttributeError:
        return str(cls)
    return f"{module}.{name}"
def get_event(obj) -> multiprocessing.Event:
    """Create an Event from the multiprocessing context matching process *obj*.

    :param obj: A multiprocessing process object (it need not be started).
    :return: An Event from the default context for ``multiprocessing.Process``
        and fork-context processes, from the ``spawn`` context for spawn
        processes, or from the ``forkserver`` context for forkserver processes.
    :raises TypeError: if *obj* is not one of the supported process types.
    """
    ctx_module = multiprocessing.context
    # ForkProcess / ForkServerProcess are only defined on POSIX. Referencing
    # them unconditionally raised AttributeError on Windows for spawn processes.
    fork_cls = getattr(ctx_module, "ForkProcess", None)
    forkserver_cls = getattr(ctx_module, "ForkServerProcess", None)

    if isinstance(obj, multiprocessing.Process) or (fork_cls is not None and isinstance(obj, fork_cls)):
        return multiprocessing.Event()
    if isinstance(obj, ctx_module.SpawnProcess):
        return multiprocessing.get_context("spawn").Event()
    if forkserver_cls is not None and isinstance(obj, forkserver_cls):
        return multiprocessing.get_context("forkserver").Event()
    raise TypeError(f'{obj} is not an instance of "multiprocessing.Process"')
def in_docker():
    """Return True when this process appears to run inside a Docker container.

    Two heuristics are tried in order: the ``/.dockerenv`` marker file, then
    a ``docker`` substring on any line of ``/proc/self/cgroup``.
    """
    if os.path.exists("/.dockerenv"):
        return True
    cgroup_file = "/proc/self/cgroup"
    if not os.path.isfile(cgroup_file):
        return False
    with open(cgroup_file, encoding="utf-8") as handle:
        for line in handle:
            if "docker" in line:
                return True
    return False
def host_is_local(hostname):
    """Return True if *hostname* refers to the local machine.

    The following count as local:
    - the literal names ``"localhost"`` and ``"0.0.0.0"``;
    - any name whose ``socket.getfqdn`` result is one of those;
    - any IP literal for which ``ipaddress`` reports ``is_loopback``.

    :param hostname: Host name or IP address string.
    :return: True if local, False otherwise.
    """
    import socket

    local_names = ("localhost", "0.0.0.0")
    # Check the literal first. getfqdn("localhost") can return an alias such as
    # "localhost.localdomain" (e.g. from /etc/hosts), which would otherwise make
    # "localhost" itself be reported as non-local. Also avoids a lookup.
    if hostname in local_names:
        return True
    if socket.getfqdn(hostname) in local_names:
        return True
    try:
        return ipaddress.ip_address(hostname).is_loopback
    except ValueError:
        return False
# Ports already handed out by random_port(); cleared by reset_ports().
assigned_ports: "set[int]" = set()
# Shuffled pool of candidate ports random_port() picks from.
unassigned_ports: "list[int]" = []
# Default inclusive bounds of the random-port range. reset_ports() lets the
# JINA_RANDOM_PORT_MIN / JINA_RANDOM_PORT_MAX env vars override them.
DEFAULT_MIN_PORT = 49153
MAX_PORT = 2**16 - 1
def reset_ports():
    """Clear both port registries and refill the shuffled candidate pool.

    The candidate range is ``JINA_RANDOM_PORT_MIN``..``JINA_RANDOM_PORT_MAX``
    (inclusive), defaulting to ``DEFAULT_MIN_PORT``..``MAX_PORT``. The pool is
    shuffled so ``random_port()`` hands ports out in random order.
    """

    def _candidate_ports():
        # NOTE(review): reset_ports() clears ``assigned_ports`` before calling
        # this helper, so ``len(assigned_ports)`` is always 0 here. With the
        # default constants, the "running out of ports" fallback to 16384 and
        # the subtraction below therefore never take effect. Confirm the intent.
        running_low = MAX_PORT - DEFAULT_MIN_PORT - len(assigned_ports) < 100
        default_min = "16384" if running_low else str(DEFAULT_MIN_PORT)
        low = int(os.environ.get("JINA_RANDOM_PORT_MIN", default_min))
        high = int(os.environ.get("JINA_RANDOM_PORT_MAX", str(MAX_PORT)))
        return set(range(low, high + 1)) - set(assigned_ports)

    unassigned_ports.clear()
    assigned_ports.clear()
    unassigned_ports.extend(_candidate_ports())
    random.shuffle(unassigned_ports)
def random_port() -> Optional[int]:
    """Reserve and return a random port that can currently be bound.

    Candidates come from the shared ``unassigned_ports`` pool, which
    ``reset_ports()`` refills when it is empty. The first candidate that binds
    successfully on all interfaces is moved from ``unassigned_ports`` into
    ``assigned_ports``. If no candidate binds, both registries are cleared and
    the search is retried once from a fresh pool.

    :return: A random port.
    :raises OSError: if no bindable port is found even after the retry.
    """

    def _pick_port():
        import socket

        def _bindable(candidate):
            # The probe socket is closed on leaving the ``with`` block, so the
            # port is only known to be bindable at this moment.
            with socket.socket() as probe:
                try:
                    probe.bind(("", candidate))
                    # NOTE(review): SO_REUSEADDR set after bind() does not affect
                    # that bind; it looks vestigial.
                    probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                except OSError:
                    return None
                return candidate

        if not unassigned_ports:
            reset_ports()
        position = next(
            (i for i, candidate in enumerate(unassigned_ports) if _bindable(candidate) is not None),
            None,
        )
        if position is None:
            raise OSError(
                f"can not find an available port in {len(unassigned_ports)} unassigned ports, assigned already {len(assigned_ports)} ports"
            )
        chosen = int(unassigned_ports[position])
        del unassigned_ports[position]
        assigned_ports.add(chosen)
        return chosen

    try:
        return _pick_port()
    except OSError:
        assigned_ports.clear()
        unassigned_ports.clear()
        return _pick_port()
class ConfigError(Exception):
    """Raised when model configuration data is missing, empty, or malformed."""
def load_model_configs(model_configs: str) -> dict:
    """Parse and validate a JSON array of model configurations.

    Each array entry must be a JSON object with non-empty values for
    ``model_name``, ``displayName``, ``endpoint``, ``minToken`` and
    ``maxToken``.

    :param model_configs: JSON text (e.g. the MODEL_CONFIGS environment variable).
    :return: Mapping of each entry's ``model_name`` to its configuration dict.
        If several entries share a ``model_name``, the last one wins.
    :raises ConfigError: if the text is not valid JSON, is not a non-empty
        array, or any entry is malformed. The problem is also logged.
    """
    logger = CustomLogger("models_loader")
    # A tuple (not a set) keeps the key order in error messages stable across
    # runs; str hashing is randomised per process, so set order is not.
    required_keys = ("model_name", "displayName", "endpoint", "minToken", "maxToken")
    try:
        configs = json.loads(model_configs)
        if not isinstance(configs, list) or not configs:
            raise ConfigError("MODEL_CONFIGS must be a non-empty JSON array.")
        configs_map = {}
        for config in configs:
            # Without this check, a string entry passes the ``in`` test below via
            # substring matching and then fails with AttributeError on ``.get``.
            if not isinstance(config, dict):
                raise ConfigError(f"Each MODEL_CONFIGS entry must be a JSON object, got {type(config).__name__}.")
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ConfigError(f"Missing required configuration fields: {missing_keys}")
            # NOTE(review): every falsy value counts as empty, so a numeric 0
            # (e.g. "minToken": 0) is rejected too. Confirm this is intended.
            empty_keys = [key for key in required_keys if not config.get(key)]
            if empty_keys:
                raise ConfigError(f"Empty values found for configuration fields: {empty_keys}")
            configs_map[config["model_name"]] = config
        return configs_map
    except json.JSONDecodeError as err:
        logger.error("Error parsing MODEL_CONFIGS environment variable as JSON.")
        raise ConfigError("MODEL_CONFIGS is not valid JSON.") from err
    except ConfigError as e:
        logger.error(str(e))
        raise
def get_access_token(token_url: str, client_id: str, client_secret: str, timeout: Optional[float] = 30.0) -> str:
    """Fetch an access token using the OAuth 2.0 client-credentials grant.

    :param token_url: Token endpoint URL.
    :param client_id: OAuth client id.
    :param client_secret: OAuth client secret.
    :param timeout: Seconds to wait for the token endpoint (passed to
        ``requests.post``); ``None`` waits indefinitely.
    :return: The ``access_token`` value. Returns ``""`` if the endpoint answers
        with a non-200 status, a body that is not a JSON object, or no token.
        Failures are logged.
    :raises requests.RequestException: on connection errors or timeouts.
    """
    logger = CustomLogger("tgi_or_tei_service_auth")
    data = {
        "client_id": client_id,
        "client_secret": client_secret,
        "grant_type": "client_credentials",
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    # Without a timeout, requests can block forever on an unresponsive endpoint.
    response = requests.post(token_url, data=data, headers=headers, timeout=timeout)
    if response.status_code != 200:
        logger.error(f"Failed to retrieve access token: {response.status_code}, {response.text}")
        return ""
    try:
        token_info = response.json()
    except ValueError:
        # requests' JSON decode error subclasses ValueError.
        logger.error("Failed to retrieve access token: response body is not valid JSON.")
        return ""
    if not isinstance(token_info, dict):
        logger.error("Failed to retrieve access token: response body is not a JSON object.")
        return ""
    return token_info.get("access_token", "")
class SafeContextManager:
    """Make sure a wrapped context's ``__exit__`` runs when an exception escapes.

    Meant for code that has begun setting up a sub-context (for example inside
    an ``__init__``) and must release it if an exception occurs before the
    normal ``with`` protocol takes over.

    - Entering yields ``None``.
    - On a clean exit, nothing happens.
    - On an exception, the wrapped object's ``__exit__`` is called with the
      exception details.

    The exception is never suppressed, whatever the wrapped ``__exit__`` returns.
    """

    def __init__(self, context_to_manage):
        # Any object exposing ``__exit__(exc_type, exc_val, exc_tb)``.
        self.context_to_manage = context_to_manage

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not exc_type:
            return None
        self.context_to_manage.__exit__(exc_type, exc_val, exc_tb)
        return None
def _text_from_parts(content):
    """Join the ``text`` items of an OpenAI-style content-part list with newlines."""
    return "\n".join(item["text"] for item in content if item["type"] == "text")


def _image_to_base64(img, timeout=30):
    """Return image reference *img* as a base64-encoded PNG string.

    *img* may be:
    - an ``http(s)://`` URL, which is downloaded;
    - a path to an existing local file, which is read from disk;
    - anything else, which is assumed to be base64 data already and is
      returned unchanged.
    """
    if img.startswith(("http://", "https://")):
        # The timeout keeps a slow or unresponsive image host from hanging the
        # caller. raise_for_status turns an error page into a clear HTTPError
        # instead of an "unidentified image" failure.
        response = requests.get(img, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGBA")
    elif os.path.exists(img):
        # SECURITY NOTE(review): if messages come from untrusted clients, this
        # branch lets them read any image file the server process can access.
        # Consider restricting paths to an allow-listed directory.
        with Image.open(img) as source:
            image = source.convert("RGBA")
    else:
        return img
    image_bytes = BytesIO()
    image.save(image_bytes, format="PNG")
    return base64.b64encode(image_bytes.getvalue()).decode()


def handle_message(messages):
    """Flatten chat messages into one text prompt plus any images.

    *messages* is either a plain string, which is returned unchanged, or a
    list of ``{"role": ..., "content": ...}`` dicts with roles ``system``,
    ``user`` or ``assistant``.

    Prompt layout:
    - The last system message (if non-empty) goes first, followed by "\\n".
    - Each user/assistant turn is then rendered in conversation order as
      ``"role: text\\n"``, or as a bare ``"role:"`` when its text is empty.

    User content may also be a list of OpenAI-style parts. ``text`` parts are
    joined with newlines; ``image_url`` parts are collected and converted with
    ``_image_to_base64``.

    :return: The prompt string, or ``(prompt, images)`` if any images were found.
    :raises ValueError: on an unknown role.
    """
    if isinstance(messages, str):
        return messages

    system_prompt = ""
    # (role, text, image references) in conversation order. A list keeps every
    # turn; keying a dict by role kept only the last user/assistant turn.
    turns = []
    for message in messages:
        msg_role = message["role"]
        if msg_role == "system":
            content = message["content"]
            system_prompt = _text_from_parts(content) if isinstance(content, list) else content
        elif msg_role == "user":
            content = message["content"]
            if isinstance(content, list):
                image_list = [item["image_url"]["url"] for item in content if item["type"] == "image_url"]
                turns.append((msg_role, _text_from_parts(content), image_list))
            else:
                turns.append((msg_role, content, []))
        elif msg_role == "assistant":
            turns.append((msg_role, message["content"], []))
        else:
            raise ValueError(f"Unknown role: {msg_role}")

    prompt = system_prompt + "\n" if system_prompt else ""
    images = []
    for role, text, image_list in turns:
        if text:
            prompt += role + ": " + text + "\n"
        else:
            # Bare "role:" (no newline), e.g. to cue the assistant's reply.
            prompt += role + ":"
        images.extend(_image_to_base64(img) for img in image_list)

    if images:
        return prompt, images
    return prompt