import base64
import io
import json
import logging
import multiprocessing
import os

import torch
import uvicorn
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from qdrant_client import QdrantClient
from transformers import CLIPModel, CLIPProcessor

# Usage:
# 1. Download the model:
#      python downloadModel.py
# 2. Run with Splunk Observability (OpenTelemetry auto-instrumentation):
#      opentelemetry-instrument python server.py
# 3. Expose the server through ngrok:
#      ngrok config add-authtoken <NGROK_AUTHTOKEN>
#      ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
#
# SECURITY NOTE(review): an ngrok auth token and the Qdrant API key used to be
# hard-coded in this file. Treat both as leaked: rotate them and supply the new
# values through the environment (see QDRANT_API_KEY below).

# ==============================
# Logging
# ==============================
logging.basicConfig(level=logging.INFO)

# ==============================
# Model & DB initialization
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Use model directly from HuggingFace (auto-downloads and caches)
model_path = "openai/clip-vit-base-patch32"  # Using base model for faster loading
logging.info(f"Loading CLIP model from HuggingFace: {model_path}...")
logging.info("(Model will be cached after first download)")
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)
model.eval()  # the model is only used for inference in this service
logging.info("Model loaded successfully.")

# Qdrant connection. Every setting can be overridden through environment
# variables; the defaults reproduce the previously hard-coded configuration.
# The remote endpoint used earlier (https://q-vector-db.beprime.mx on port 443)
# corresponds to QDRANT_HOST=q-vector-db.beprime.mx QDRANT_PORT=443 QDRANT_HTTPS=true.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "172.21.4.201")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "5050"))
QDRANT_HTTPS = os.environ.get("QDRANT_HTTPS", "false").strip().lower() in ("1", "true", "yes")
# TODO(security): drop this fallback once the key has been rotated. With https
# disabled (the default here) the key also travels over the network unencrypted.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z")

qdrant = QdrantClient(
    host=QDRANT_HOST,
    port=QDRANT_PORT,
    api_key=QDRANT_API_KEY,
    https=QDRANT_HTTPS,
)
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION", "nazan-ssa")

# ==============================
# Helper functions
# ==============================
def get_embeddings(image_bytes: bytes):
    """Return the L2-normalized CLIP image embedding for an encoded image.

    Args:
        image_bytes: Raw encoded image file contents (any format PIL can open).

    Returns:
        A flat float32 NumPy array holding the model's image features,
        scaled to unit L2 norm.

    Raises:
        ValueError: if ``image_bytes`` is empty.
        Any decoding (PIL) or inference (torch) error, after logging it.
    """
    try:
        if not image_bytes:
            # Fail with a clear message instead of an opaque image-decoding error.
            raise ValueError("Empty image data")
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            embedding = model.get_image_features(**inputs)
        # Unit-normalize the vector. Presumably the collection's distance is
        # cosine/dot, so scores are comparable to `threshold` — TODO confirm.
        embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
        return embedding.cpu().numpy().astype("float32").flatten()
    except Exception as e:
        logging.error("Error in get_embeddings: %s", e)
        raise


def similarity_search(embedding, top_k=1, threshold=0.7):
    """Look up the closest product for an embedding in the Qdrant collection.

    Args:
        embedding: Query vector (NumPy array) as returned by ``get_embeddings``.
        top_k: Number of neighbours requested from Qdrant; only the best is used.
        threshold: Minimum score the best hit must reach to count as a match.

    Returns:
        ``{"status": True, "SKU": <payload "product_id">}`` on a match, otherwise
        ``{"status": False, "SKU": None}``.

    Raises:
        Any Qdrant/client error, after logging it, so callers can report the
        failure instead of receiving ``None``.
    """
    try:
        results = qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=embedding.astype(float).tolist(),
            limit=top_k,
        )
        # Accept either a plain list of points or a response object exposing `.points`.
        points = getattr(results, "points", results)
        if not points:
            return {"status": False, "SKU": None}

        best = points[0]
        if best.score >= threshold:
            # Guard against a hit stored without a payload.
            sku = (best.payload or {}).get("product_id")
            logging.info("Match: SKU=%s score=%s (threshold=%s)", sku, best.score, threshold)
            return {"status": True, "SKU": sku}

        logging.info("No match: best score=%s below threshold=%s", best.score, threshold)
        return {"status": False, "SKU": None}
    except Exception as e:
        logging.error("Error in similarity_search: %s", e)
        raise


def _decode_base64_image(encoded):
    """Decode a base64 image string, accepting an optional ``data:<mime>;base64,`` prefix.

    The prefix (e.g. from a browser's FileReader.readAsDataURL) must be stripped
    explicitly: non-validating ``b64decode`` would otherwise keep the letters of
    "data", "image", "base64" and corrupt the decoded bytes.
    """
    if encoded.startswith("data:") and "," in encoded:
        encoded = encoded.split(",", 1)[1]
    return base64.b64decode(encoded)


# ==============================
# FastAPI app
# ==============================
app = FastAPI()

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True likely lets
# any origin make credentialed requests (Starlette echoes the request Origin in this
# configuration) — confirm this is intended before exposing cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or list of domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): the handlers below are `async def` but call blocking code
# (torch inference and the synchronous QdrantClient), which blocks the event loop
# while a request is processed. Consider fastapi.concurrency.run_in_threadpool
# if concurrent throughput matters.


@app.post("/predict")
async def predict(request: Request):
    """Identify the product in a base64-encoded image sent as JSON.

    Expects a JSON object body ``{"image_bytes": "<base64 or data: URL>"}``.
    Returns the ``similarity_search`` result, or ``{"status": False, "error": ...}``
    when the input is missing/invalid or any step fails.
    """
    try:
        data = await request.json()
        if not isinstance(data, dict) or "image_bytes" not in data:
            return {"status": False, "error": "Missing 'image_bytes'"}
        image_bytes = _decode_base64_image(data["image_bytes"])
        embedding = get_embeddings(image_bytes)
        return similarity_search(embedding)
    except Exception as e:
        logging.exception("Error in /predict: %s", e)
        return {"status": False, "error": str(e)}


@app.post("/predictfile")
async def predict_file(file: UploadFile = File(...)):
    """Identify the product in an uploaded image (multipart form field ``file``).

    Returns the ``similarity_search`` result, or ``{"status": False, "error": ...}``
    when any step fails.
    """
    try:
        image_bytes = await file.read()
        embedding = get_embeddings(image_bytes)
        return similarity_search(embedding)
    except Exception as e:
        logging.exception("Error in /predictfile: %s", e)
        return JSONResponse(content={"status": False, "error": str(e)})


def run_server():
    """Serve this module's ``app`` with uvicorn on 0.0.0.0:8000."""
    # Pass the app object, not the "server:app" import string. When this file is
    # run as a script it is the `__main__` module, and the import string makes
    # uvicorn import it a second time as `server` — re-running all module-level
    # setup (loading the CLIP model and creating the Qdrant client twice) and
    # failing outright if the file is not named server.py.
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    # For a public URL, run ngrok separately (see the usage notes at the top of the file).
    run_server()