167 lines
5.1 KiB
Python
167 lines
5.1 KiB
Python
import torch
|
|
from transformers import CLIPProcessor, CLIPModel
|
|
from PIL import Image
|
|
import io
|
|
import json
|
|
from qdrant_client import QdrantClient
|
|
import base64
|
|
import logging
|
|
from fastapi import FastAPI, Request
|
|
import uvicorn
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import uvicorn
|
|
import multiprocessing
|
|
from fastapi import FastAPI, File, UploadFile
|
|
from fastapi.responses import JSONResponse
|
|
|
|
#from pyngrok import ngrok, conf
|
|
|
|
#conf.get_default().auth_token = os.environ["NGROK_AUTHTOKEN"]  # never hard-code the token; rotate the one previously committed here
|
|
|
|
# 1. (Optional) Pre-download the model — CLIPModel.from_pretrained() below also downloads and caches it on first start
|
|
# python downloadModel.py
|
|
|
|
# 2. Command to run with Splunk Observability (OpenTelemetry auto-instrumentation)
|
|
# opentelemetry-instrument python server.py
|
|
|
|
# 3. Command to run ngrok
|
|
# ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
|
|
|
|
# ngrok config add-authtoken <YOUR_NGROK_AUTHTOKEN>   (do not commit real tokens)
|
|
# ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
|
|
|
|
# ==============================
# Logging
# ==============================
# Configure the root logger at INFO so the start-up and request messages
# emitted through the module-level `logging.*` calls below are shown.
# (basicConfig accepts the level by name; "INFO" resolves to logging.INFO.)
logging.basicConfig(level="INFO")
|
|
|
|
# ==============================
# Model & DB initialization
# ==============================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"Using device: {device}")

# CLIP checkpoint on the Hugging Face Hub. from_pretrained() fetches it on the
# first run and reuses the local cache afterwards; the base ViT-B/32 variant is
# used to keep start-up fast.
# NOTE(review): the vectors stored in Qdrant presumably come from this same
# checkpoint — changing it would make query and stored embeddings incomparable.
model_path = "openai/clip-vit-base-patch32"

logging.info(f"Loading CLIP model from HuggingFace: {model_path}...")
logging.info("(Model will be cached after first download)")

model = CLIPModel.from_pretrained(model_path)
model = model.to(device)
# Evaluation mode: turns off training-only behaviour (e.g. dropout).
model.eval()

processor = CLIPProcessor.from_pretrained(model_path)
logging.info("Model loaded successfully.")
|
|
|
|
# Qdrant connection
# Settings are read from the environment so credentials stay out of source
# control. Defaults reproduce the previous on-prem connection, except for the
# API key, which must now be supplied via QDRANT_API_KEY.
# SECURITY(review): the API key that used to be hard-coded here is still in the
# VCS history and should be rotated.
# To target the former managed endpoint instead, set
# QDRANT_HOST=q-vector-db.beprime.mx, QDRANT_PORT=443 and QDRANT_HTTPS=true.
import os

_qdrant_api_key = os.environ.get("QDRANT_API_KEY") or None
if _qdrant_api_key is None:
    logging.warning("QDRANT_API_KEY is not set; connecting to Qdrant without an API key.")

qdrant = QdrantClient(
    host=os.environ.get("QDRANT_HOST", "172.21.4.201"),
    port=int(os.environ.get("QDRANT_PORT", "5050")),
    api_key=_qdrant_api_key,
    # NOTE(review): with https disabled the API key is sent in clear text;
    # set QDRANT_HTTPS=true outside a trusted network.
    https=os.environ.get("QDRANT_HTTPS", "false").strip().lower() in ("1", "true", "yes"),
)

# Collection searched by similarity_search().
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION", "nazan-ssa")
|
|
|
|
# ==============================
|
|
# Helper functions
|
|
# ==============================
|
|
def get_embeddings(image_bytes: bytes):
    """Encode an image as an L2-normalised CLIP image embedding.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        A flattened float32 NumPy array: the output of
        ``model.get_image_features`` for the single input image, divided by
        its L2 norm.

    Raises:
        Any exception from decoding the image or running the model is logged
        together with its traceback and then re-raised unchanged.
    """
    try:
        # The context manager releases PIL's lazily-read source; convert()
        # returns a new, fully loaded image that does not depend on it.
        with Image.open(io.BytesIO(image_bytes)) as img:
            image = img.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            embedding = model.get_image_features(**inputs)

        # Unit length, so a dot product between embeddings is their cosine similarity.
        embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
        return embedding.cpu().numpy().astype("float32").flatten()
    except Exception as e:
        # logging.exception keeps the traceback that logging.error dropped.
        logging.exception(f"Error in get_embeddings: {e}")
        raise
|
|
|
|
def similarity_search(embedding, top_k=1, threshold=0.7):
    """Find the stored product closest to an image embedding.

    Args:
        embedding: NumPy vector (e.g. from ``get_embeddings``) used as the
            Qdrant query vector.
        top_k: Number of nearest points requested from Qdrant; only the first
            (best-ranked) one is inspected.
        threshold: Minimum score the best hit must reach to count as a match.

    Returns:
        ``{"status": True, "SKU": payload["product_id"]}`` when the best hit
        scores at least ``threshold``; otherwise ``{"status": False, "SKU": None}``
        (also when Qdrant returns no hits).

    Raises:
        Errors from building the query or calling Qdrant are logged and
        re-raised. Previously they were swallowed and the function returned
        ``None``, which the endpoints then sent back as ``null``.
    """
    try:
        # NOTE(review): QdrantClient.search is deprecated in newer qdrant-client
        # releases in favour of query_points — verify against the pinned version.
        results = qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=embedding.astype(float).tolist(),
            limit=top_k,
        )
    except Exception as e:
        logging.exception(f"Error in similarity_search: {e}")
        raise

    # Accept both a plain list of hits and a response object exposing `.points`.
    points = getattr(results, "points", results)
    if not points:
        return {"status": False, "SKU": None}

    best = points[0]
    if best.score < threshold:
        logging.info(f"No match: best score {best.score} below threshold {threshold}")
        return {"status": False, "SKU": None}

    # Payload can be None if the point was stored/returned without one.
    sku = (best.payload or {}).get("product_id")
    logging.info(f"Match: SKU={sku} score={best.score}")
    return {"status": True, "SKU": sku}
|
|
|
|
# ==============================
# FastAPI app
# ==============================
import os

app = FastAPI()

# Allowed browser origins, comma-separated, e.g.
# CORS_ALLOW_ORIGINS="https://shop.example.com,https://admin.example.com".
# Defaults to "*" (any origin), as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_any = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials (see the FastAPI
    # CORS docs); Starlette would then reflect any caller's Origin together
    # with Access-Control-Allow-Credentials. No endpoint in this file reads
    # cookies or auth headers, so credentials are only enabled when an
    # explicit origin list is configured.
    allow_credentials=not _cors_allow_any,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
@app.post("/predict")
async def predict(request: Request):
    """Identify the product shown in a base64-encoded image sent as JSON.

    Expects a JSON object body ``{"image_bytes": "<base64>"}``. A data-URL
    header (``data:<mime>;base64,``, as produced by e.g.
    ``FileReader.readAsDataURL``) is tolerated and stripped before decoding.

    Returns:
        The ``similarity_search`` result dict on success, or
        ``{"status": False, "error": <message>}`` on any failure. Errors are
        reported in the body with HTTP 200, which existing clients may rely on.
    """
    # NOTE(review): get_embeddings/similarity_search are synchronous (model
    # inference and a blocking Qdrant call), so they run on the event loop and
    # requests are processed one at a time. Offloading them (e.g.
    # run_in_threadpool) would add concurrency — mind GPU memory if you do.
    try:
        data = await request.json()

        if not isinstance(data, dict) or "image_bytes" not in data:
            return {"status": False, "error": "Missing 'image_bytes'"}

        encoded = data["image_bytes"]
        # Plain base64 never starts with "data:", so only data URLs are touched.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            encoded = encoded.partition(",")[2]

        image_bytes = base64.b64decode(encoded)
        embedding = get_embeddings(image_bytes)
        return similarity_search(embedding)

    except Exception as e:
        logging.exception(f"Error in /predict: {e}")
        return {"status": False, "error": str(e)}
|
|
|
|
@app.post("/predictfile")
async def predict_file(file: UploadFile = File(...)):
    """Identify the product shown in an image uploaded as multipart form data.

    Args:
        file: The uploaded image (form field ``file``); it is read fully into
            memory.

    Returns:
        The ``similarity_search`` result dict on success, or
        ``{"status": False, "error": <message>}`` on failure — the same shape
        and HTTP 200 status as ``/predict``.
    """
    # NOTE(review): as in /predict, the synchronous model/Qdrant work below
    # runs on the event loop.
    try:
        image_bytes = await file.read()
        if not image_bytes:
            return {"status": False, "error": "Empty file"}

        embedding = get_embeddings(image_bytes)
        return similarity_search(embedding)

    except Exception as e:
        logging.exception(f"Error in /predictfile: {e}")
        return {"status": False, "error": str(e)}
|
|
|
|
|
|
|
|
|
|
def run_server(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn; blocks until the server stops.

    The already-built ``app`` object is passed instead of the ``"server:app"``
    import string. The string form makes uvicorn import this file a second
    time under the module name ``server`` when it was started as a script,
    which re-runs the module-level CLIP model load and Qdrant client setup.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
    """
    uvicorn.run(app, host=host, port=port)
|
|
|
|
if __name__ == "__main__":
    # Public exposure is done with the ngrok CLI in a separate shell (see the
    # commands near the top of this file), not from inside this process.
    #
    # Pass the app object rather than the "server:app" import string: the
    # string makes uvicorn import this file again as module "server", loading
    # the CLIP model and creating the Qdrant client a second time.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|