# temp_SSA_SCAN/python_server/server.py
import base64
import io
import logging

import torch
import uvicorn
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from qdrant_client import QdrantClient
from transformers import CLIPProcessor, CLIPModel
#from pyngrok import ngrok, conf
#conf.get_default().auth_token = "31y03xgLo8cpZ0WujlWkTCT34qK_86cXFLfdH2c3gX3SEezaz"
# 1. Download the model:
#    python downloadModel.py
# 2. Run with Splunk Observability instrumentation:
#    opentelemetry-instrument python server.py
# 3. Expose the server through ngrok:
#    ngrok config add-authtoken 31y03xgLo8cpZ0WujlWkTCT34qK_86cXFLfdH2c3gX3SEezaz
#    ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
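# downloadModel.py is not included in this file; a minimal sketch of what it
# presumably does (an assumption, not the actual script) is to pre-cache the
# CLIP weights so the server starts without a network download:
#
#     from transformers import CLIPProcessor, CLIPModel
#     CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#     CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")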
# ==============================
# Logging
# ==============================
logging.basicConfig(level=logging.INFO)
# ==============================
# Model & DB initialization
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
# Use model directly from HuggingFace (auto-downloads and caches)
model_path = "openai/clip-vit-base-patch32" # Using base model for faster loading
logging.info(f"Loading CLIP model from HuggingFace: {model_path}...")
logging.info("(Model will be cached after first download)")
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)
model.eval()
logging.info("Model loaded successfully.")
# Qdrant connection
'''qdrant = QdrantClient(
    url="https://q-vector-db.beprime.mx",
    port=443,
    api_key="GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z",
    https=True
)'''
qdrant = QdrantClient(
    host="172.21.4.201",
    port=5050,
    api_key="GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z",
    https=False
)
COLLECTION_NAME = "nazan-ssa"
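# The collection is assumed to already exist. A hedged sketch of how it could
# be created (512-dim vectors matching CLIP ViT-B/32 output, cosine distance
# matching the L2-normalized embeddings below; these parameters are
# assumptions, not taken from this repo):
#
#     from qdrant_client.models import Distance, VectorParams
#     qdrant.create_collection(
#         collection_name=COLLECTION_NAME,
#         vectors_config=VectorParams(size=512, distance=Distance.COSINE),
#     )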
# ==============================
# Helper functions
# ==============================
def get_embeddings(image_bytes: bytes):
    """Encode an image into a unit-length CLIP embedding."""
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            embedding = model.get_image_features(**inputs)
        # L2-normalize so dot product equals cosine similarity in Qdrant
        embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
        return embedding.cpu().numpy().astype("float32").flatten()
    except Exception as e:
        logging.error(f"Error in get_embeddings: {e}")
        raise
def similarity_search(embedding, top_k=1, threshold=0.7):
    """Return the best-matching SKU if its score clears the threshold."""
    try:
        embedding_list = embedding.astype(float).tolist()
        results = qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=embedding_list,
            limit=top_k
        )
        points = getattr(results, "points", results)  # v3/v4 compatibility
        if not points:
            return {"status": False, "SKU": None}
        best = points[0]
        if best.score >= threshold:
            logging.info(f"Match: SKU={best.payload.get('product_id')}, score={best.score}")
            return {"status": True, "SKU": best.payload.get("product_id")}
        logging.info(f"No match above threshold {threshold}, score={best.score}")
        return {"status": False, "SKU": None}
    except Exception as e:
        logging.error(f"Error in similarity_search: {e}")
        return {"status": False, "SKU": None, "error": str(e)}
# ==============================
# FastAPI app
# ==============================
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or an explicit list of domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/predict")
async def predict(request: Request):
    try:
        data = await request.json()
        if "image_bytes" not in data:
            return {"status": False, "error": "Missing 'image_bytes'"}
        image_bytes = base64.b64decode(data["image_bytes"])
        embedding = get_embeddings(image_bytes)
        result = similarity_search(embedding)
        return result
    except Exception as e:
        logging.error(f"Error in /predict: {e}")
        return {"status": False, "error": str(e)}
@app.post("/predictfile")
async def predict_file(file: UploadFile = File(...)):
    try:
        # Read uploaded file as bytes
        image_bytes = await file.read()
        # Generate embedding
        embedding = get_embeddings(image_bytes)
        # Run similarity search
        result = similarity_search(embedding)
        return result
    except Exception as e:
        logging.error(f"Error in /predictfile: {e}")
        return JSONResponse(content={"status": False, "error": str(e)})
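# Example multipart upload for /predictfile ("item.jpg" is a placeholder):
#
#     curl -X POST -F "file=@item.jpg" http://localhost:8000/predictfile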
def run_server():
    uvicorn.run("server:app", host="0.0.0.0", port=8000)


if __name__ == "__main__":
    # Open ngrok tunnel
    # public_url = ngrok.connect(8000).public_url
    # print("🌍 Public URL:", public_url)
    # Run FastAPI server in a subprocess
    # p = multiprocessing.Process(target=run_server)
    # p.start()
    # p.join()
    uvicorn.run("server:app", host="0.0.0.0", port=8000)