Files
temp_SSA_SCAN/python_server/server.py
Jorge Arturo Mendez Vargas 4ea7d1516b - Added /model to git ignore.
- Modified server file a bit.
- Added example Dockerfile
2025-09-01 12:14:21 -06:00

170 lines
5.0 KiB
Python

import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import io
import json
from qdrant_client import QdrantClient
import base64
import logging
from fastapi import FastAPI, Request
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
#from pyngrok import ngrok, conf
#conf.get_default().auth_token = "31y03xgLo8cpZ0WujlWkTCT34qK_86cXFLfdH2c3gX3SEezaz"
# 1. Download Model
# python downloadModel.py
# 2. Command to run with splunk observability
# opentelemetry-instrument python server.py
# 3. Command to run ngrok
# ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
# ngrok config add-authtoken 31y03xgLo8cpZ0WujlWkTCT34qK_86cXFLfdH2c3gX3SEezaz
# ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
# ==============================
# Logging
# ==============================
# Configure the root logger once, before any of the logging calls below run.
logging.basicConfig(level=logging.INFO)
# ==============================
# Model & DB initialization
# ==============================
# Everything in this section runs at import time, so importing this module
# loads the CLIP weights before the app can serve a request.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
# Model location. This is a relative path, so it is resolved against the
# process working directory.
# NOTE(review): presumably this folder is produced by downloadModel.py (see the
# run notes near the top of the file) - confirm.
# Alternative (smaller) checkpoint that was used previously:
#model_path = "./model/clip-vit-base-patch32"
model_path = "models/clip-vit-large-patch14"
logging.info("Loading CLIP model...")
# The model is moved to the selected device; the processor is not.
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)
# Switch the model to evaluation mode; this server only runs inference.
model.eval()
logging.info("Model loaded.")
# Qdrant connection
import os

# Connection settings are read from the environment so a deployment can point
# at another Qdrant instance without editing this file. Every default below is
# the value that used to be hard-coded, so behaviour is unchanged when no
# variable is set.
#
# The hosted endpoint that was previously kept here as a disabled snippet
# (url="https://q-vector-db.beprime.mx", port=443, https=True) should
# correspond to:
#   QDRANT_HOST=q-vector-db.beprime.mx QDRANT_PORT=443 QDRANT_HTTPS=true
#
# NOTE(review): the fallback API key is committed to source control. Treat it
# as leaked: rotate it, supply the new one through QDRANT_API_KEY, and drop the
# default once every deployment sets the variable.
qdrant = QdrantClient(
    host=os.environ.get("QDRANT_HOST", "172.21.4.201"),
    port=int(os.environ.get("QDRANT_PORT", "5050")),
    api_key=os.environ.get("QDRANT_API_KEY", "GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z"),
    # Only explicit truthy spellings enable HTTPS; the default stays plain HTTP.
    https=os.environ.get("QDRANT_HTTPS", "false").strip().lower() in ("1", "true", "yes"),
)
# Name of the collection that similarity_search() queries.
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION", "nazan-ssa")
# ==============================
# Helper functions
# ==============================
def get_embeddings(image_bytes: bytes):
    """Turn raw image bytes into an L2-normalised CLIP image embedding.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        A flattened ``float32`` NumPy array holding the normalised image
        features produced by the module-level ``model``.

    Raises:
        Exception: Any failure while decoding the image or running the model
            is logged and then re-raised unchanged.
    """
    try:
        # Force three channels so every input reaches the processor as RGB.
        rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        batch = processor(images=rgb_image, return_tensors="pt")
        # The model lives on `device`, so its inputs have to be moved there too.
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        # Inference only: no gradients are needed.
        with torch.no_grad():
            features = model.get_image_features(**on_device)
        # Scale to unit length along the feature axis.
        unit_features = features / features.norm(p=2, dim=-1, keepdim=True)
        as_numpy = unit_features.cpu().numpy()
        return as_numpy.astype("float32").flatten()
    except Exception as e:
        logging.error(f"Error in get_embeddings: {e}")
        raise
def similarity_search(embedding, top_k=1, threshold=0.7):
    """Query Qdrant for the nearest stored vector and report its SKU.

    Args:
        embedding: Query vector. It must provide ``astype`` and ``tolist``
            (the NumPy array returned by ``get_embeddings`` does).
        top_k: Maximum number of hits requested from Qdrant. Only the first
            returned hit is inspected.
        threshold: Minimum score the first hit must reach to count as a match.

    Returns:
        ``{"status": True, "SKU": <payload "sku" value>}`` when the first hit
        scores at least ``threshold``; ``{"status": False, "SKU": None}`` when
        there are no hits or the score is below the threshold.

    Raises:
        Exception: Any error raised while converting the vector or querying
            Qdrant is logged and re-raised. Previously the error was swallowed
            and the function silently returned ``None``, which the endpoints
            then sent to the client as ``null``.
    """
    try:
        query_vector = embedding.astype(float).tolist()
        results = qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_vector,
            limit=top_k,
        )
        # Accept both a plain list of hits and a response object that wraps
        # them in a ``points`` attribute.
        points = getattr(results, "points", results)
        if not points:
            return {"status": False, "SKU": None}

        best = points[0]
        if best.score >= threshold:
            # A hit can come back without a payload; treat that as "no SKU"
            # rather than failing on ``None.get``.
            sku = (best.payload or {}).get("sku")
            logging.info("Match found: SKU=%s score=%s", sku, best.score)
            return {"status": True, "SKU": sku}

        logging.info(
            "No match: best score %s is below threshold %s", best.score, threshold
        )
        return {"status": False, "SKU": None}
    except Exception as e:
        logging.error(f"Error in similarity_search: {e}")
        # Propagate so callers can report the failure instead of receiving None.
        raise
# ==============================
# FastAPI app
# ==============================
# Application object; the route decorators below register handlers on it, and
# uvicorn is pointed at it through the "server:app" import string.
app = FastAPI()
# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): combining a wildcard origin with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; consider listing the allowed
# origins explicitly - confirm which front ends call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # or list of domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/predict")
async def predict(request: Request):
    """Match a base64-encoded image sent in a JSON body against the collection.

    The request body is parsed as JSON and must contain an ``image_bytes``
    entry holding the base64-encoded image.

    Returns:
        The dictionary produced by ``similarity_search`` on success, or
        ``{"status": False, "error": <message>}`` when the field is missing or
        any step raises. Errors are reported in the body, not raised.
    """
    try:
        body = await request.json()
        if "image_bytes" in body:
            raw_image = base64.b64decode(body["image_bytes"])
            return similarity_search(get_embeddings(raw_image))
        return {"status": False, "error": "Missing 'image_bytes'"}
    except Exception as e:
        logging.error(f"Error in /predict: {e}")
        return {"status": False, "error": str(e)}
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
@app.post("/predictfile")
async def predict_file(file: UploadFile = File(...)):
    """Match an uploaded image file against the collection.

    Args:
        file: The uploaded file; its full content is read and treated as the
            encoded image.

    Returns:
        The dictionary produced by ``similarity_search`` on success, or a
        ``JSONResponse`` carrying ``{"status": False, "error": <message>}``
        when any step raises.
    """
    try:
        uploaded_bytes = await file.read()
        # Embed the image, then look the embedding up in the vector store.
        return similarity_search(get_embeddings(uploaded_bytes))
    except Exception as e:
        logging.error(f"Error in /predictfile: {e}")
        error_body = {"status": False, "error": str(e)}
        return JSONResponse(content=error_body)
#from pyngrok import ngrok
import uvicorn
import multiprocessing
def run_server():
    """Start uvicorn serving ``server:app`` on all interfaces, port 8000."""
    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run("server:app", **bind_options)
if __name__ == "__main__":
    # Script entry point: serve the API directly in this process.
    #
    # Disabled experiments kept for reference:
    #   - opening an ngrok tunnel first:
    #       public_url = ngrok.connect(8000).public_url
    #       print("🌍 Public URL:", public_url)
    #   - running the server in a child process:
    #       p = multiprocessing.Process(target=run_server)
    #       p.start()
    #       p.join()
    run_server()