- Added python-server folder

This commit is contained in:
2025-09-01 11:31:10 -06:00
parent 9ab56dfbfe
commit e8ad6dd99a
4 changed files with 205 additions and 0 deletions

Binary file not shown.

View File

@@ -0,0 +1,21 @@
from transformers import CLIPModel, CLIPProcessor
# pip install torch torchvision torchaudio
# pip install transformers
# Name of the CLIP checkpoint to download. The Hugging Face Hub id and the
# local folder are both derived from this single name so they cannot drift
# apart: previously the base-patch32 weights were saved into a folder named
# clip-vit-large-patch14.
# NOTE(review): large-patch14 is used because that is the active local path;
# set MODEL_NAME to "clip-vit-base-patch32" if the smaller model was intended.
MODEL_NAME = "clip-vit-large-patch14"
hub_model_id = f"openai/{MODEL_NAME}"
# Local folder where the model is saved
local_model_path = f"models/{MODEL_NAME}"
# Load model & processor
model = CLIPModel.from_pretrained(hub_model_id)
processor = CLIPProcessor.from_pretrained(hub_model_id)
# Save them locally
model.save_pretrained(local_model_path)
processor.save_pretrained(local_model_path)
print(f"Model saved to {local_model_path}")
# Then package the saved model from the command line. The original command
# was cut off after "mode"; presumably the models folder was meant:
# tar -czvf clip_model.tar.gz models

164
python_server/server.py Normal file
View File

@@ -0,0 +1,164 @@
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import io
import json
from qdrant_client import QdrantClient
import base64
import logging
from fastapi import FastAPI, Request
import uvicorn
#from pyngrok import ngrok, conf
#conf.get_default().auth_token = "<NGROK_AUTHTOKEN>"  # never commit the real token; load it from the environment
# ngrok config add-authtoken <NGROK_AUTHTOKEN>
# ngrok http --url=pegasus-working-bison.ngrok-free.app 8000
# Command to run with splunk observability
# opentelemetry-instrument python server.py
from fastapi.middleware.cors import CORSMiddleware
# ==============================
# Logging
# ==============================
logging.basicConfig(level=logging.INFO)
# ==============================
# Model & DB initialization
# ==============================
# Imported here so this section stays self-contained; used only to read the
# optional environment overrides below.
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info("Using device: %s", device)
# Local model folder. Override with CLIP_MODEL_PATH, e.g.
# "./model/clip-vit-base-patch32".
model_path = os.environ.get("CLIP_MODEL_PATH", "models/clip-vit-large-patch14")
logging.info("Loading CLIP model...")
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)
model.eval()
logging.info("Model loaded.")
# Qdrant connection.
# Every setting can be overridden through the environment; the defaults are
# the values that used to be hard-coded, so behaviour is unchanged when no
# variable is set.
# The hosted instance that was previously kept as a commented-out client used
# url="https://q-vector-db.beprime.mx", port=443, https=True; setting
# QDRANT_HOST / QDRANT_PORT / QDRANT_HTTPS accordingly should target it.
# SECURITY: the fallback API key below is committed to source control. Rotate
# it, supply the new value via QDRANT_API_KEY and then delete this fallback.
qdrant = QdrantClient(
    host=os.environ.get("QDRANT_HOST", "172.21.4.201"),
    port=int(os.environ.get("QDRANT_PORT", "5050")),
    api_key=os.environ.get("QDRANT_API_KEY", "GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z"),
    https=os.environ.get("QDRANT_HTTPS", "false").strip().lower() in ("1", "true", "yes"),
)
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION", "nazan-ssa")
# ==============================
# Helper functions
# ==============================
def get_embeddings(image_bytes: bytes):
    """Return the L2-normalised CLIP image embedding for an encoded image.

    The bytes are decoded with PIL, converted to RGB, preprocessed by the
    module-level ``processor`` and passed through ``model`` on ``device``.

    Args:
        image_bytes: Raw bytes of an image file that PIL can open.

    Returns:
        A flattened ``float32`` NumPy array holding the normalised features.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    try:
        rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        batch = processor(images=rgb_image, return_tensors="pt")
        # Move every preprocessed tensor to the device the model lives on.
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        with torch.no_grad():
            features = model.get_image_features(**on_device)
            l2_norm = features.norm(p=2, dim=-1, keepdim=True)
            unit_features = features / l2_norm
        as_numpy = unit_features.cpu().numpy()
        return as_numpy.astype("float32").flatten()
    except Exception as e:
        logging.error(f"Error in get_embeddings: {e}")
        raise
def similarity_search(embedding, top_k=1, threshold=0.7):
    """Look up the closest stored vector in Qdrant and report its SKU.

    Args:
        embedding: NumPy array holding the query vector.
        top_k: Maximum number of hits requested from Qdrant; only the first
            hit is inspected.
        threshold: Minimum score the best hit must reach to count as a match.

    Returns:
        ``{"status": True, "SKU": <sku>}`` when the best hit scores at least
        ``threshold``, otherwise ``{"status": False, "SKU": None}``.

    Raises:
        Exception: Any failure is logged and re-raised. Previously the error
            was swallowed and the function implicitly returned ``None``,
            which the HTTP handlers passed on to clients as ``null``.
    """
    try:
        query_vector = embedding.astype(float).tolist()
        results = qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_vector,
            limit=top_k,
        )
        # Accept either a plain list of hits or a response object exposing
        # ``.points`` (kept from the original "v3/v4 compatibility" shim).
        points = getattr(results, "points", results)
        if not points:
            return {"status": False, "SKU": None}

        best = points[0]
        if best.score < threshold:
            logging.info(
                "No match: best score %s is below threshold %s",
                best.score,
                threshold,
            )
            return {"status": False, "SKU": None}

        # A stored point may carry no payload at all; treat that as "no SKU"
        # instead of failing with AttributeError.
        payload = best.payload or {}
        match = {"status": True, "SKU": payload.get("sku")}
        logging.info("Match %s with score %s", match, best.score)
        return match
    except Exception as e:
        logging.error(f"Error in similarity_search: {e}")
        raise
# ==============================
# FastAPI app
# ==============================
# Application instance; the route handlers below are registered on it.
app = FastAPI()

# CORS settings, collected in one mapping and handed to the middleware.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this is intended or replace "*" with an explicit list of domains.
_cors_options = {
    "allow_origins": ["*"],  # or list of domains
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
@app.post("/predict")
async def predict(request: Request):
    """Match a base64-encoded image against the vector database.

    Expects a JSON object with an ``image_bytes`` key holding the
    base64-encoded image.

    Returns:
        The dictionary produced by ``similarity_search``, or
        ``{"status": False, "error": <message>}`` when the request is
        malformed or any processing step fails.
    """
    try:
        data = await request.json()
        # Reject bodies that are valid JSON but not an object (list, string,
        # number) with the same message as a missing key.
        if not isinstance(data, dict) or "image_bytes" not in data:
            return {"status": False, "error": "Missing 'image_bytes'"}
        image_bytes = base64.b64decode(data["image_bytes"])
        embedding = get_embeddings(image_bytes)
        result = similarity_search(embedding)
        if result is None:
            # The search helper produced no result; report it as an error
            # rather than answering with a bare ``null`` body.
            return {"status": False, "error": "Similarity search failed"}
        return result
    except Exception as e:
        logging.exception("Error in /predict: %s", e)
        return {"status": False, "error": str(e)}
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
@app.post("/predictfile")
async def predict_file(file: UploadFile = File(...)):
    """Match an uploaded image file against the vector database.

    Args:
        file: Multipart upload containing the image.

    Returns:
        The dictionary produced by ``similarity_search``, or
        ``{"status": False, "error": <message>}`` when any step fails. Both
        paths now return a plain dict, matching the ``/predict`` handler.
    """
    try:
        # Read uploaded file as bytes
        image_bytes = await file.read()
        # Generate embedding
        embedding = get_embeddings(image_bytes)
        # Run similarity search
        result = similarity_search(embedding)
        if result is None:
            # The search helper produced no result; report it as an error
            # rather than answering with a bare ``null`` body.
            return {"status": False, "error": "Similarity search failed"}
        return result
    except Exception as e:
        logging.exception("Error in /predictfile: %s", e)
        return {"status": False, "error": str(e)}
#from pyngrok import ngrok
import uvicorn
import multiprocessing
def run_server():
    """Start uvicorn for the ``server:app`` import string on 0.0.0.0:8000."""
    listen_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run("server:app", **listen_options)
if __name__ == "__main__":
    # Optional ngrok tunnel (disabled):
    #public_url = ngrok.connect(8000).public_url
    #print("🌍 Public URL:", public_url)
    # Optional: run the FastAPI server in a subprocess (disabled):
    #p = multiprocessing.Process(target=run_server)
    #p.start()
    #p.join()
    # Serve the already-constructed ``app`` object. Passing the import string
    # "server:app" made uvicorn import this file a second time under the
    # module name ``server``, which re-ran the top-level code and therefore
    # loaded the CLIP model and created the Qdrant client twice.
    uvicorn.run(app, host="0.0.0.0", port=8000)

20
python_server/windmill.py Normal file
View File

@@ -0,0 +1,20 @@
import wmill
from wmill import S3Object
import base64
import json
import requests
LOCAL_ENDPOINT = "http://127.0.0.1:8000/predict"
# Upper bound, in seconds, on how long a single prediction call may take.
# Without a timeout ``requests`` waits indefinitely if the server hangs.
# NOTE(review): 120 s is a generous guess; tune it to the observed latency.
REQUEST_TIMEOUT_SECONDS = 120


def main(input_file: S3Object):
    """Send an image stored in S3 to the local prediction endpoint.

    Args:
        input_file: S3 object reference of the image to classify.

    Returns:
        The decoded JSON body returned by the ``/predict`` endpoint.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within
            ``REQUEST_TIMEOUT_SECONDS``.
    """
    # Load image from S3
    s3_bytes = wmill.load_s3_file(input_file)
    payload = {"image_bytes": base64.b64encode(s3_bytes).decode("utf-8")}
    print("CALLING LOCAL ENDPOINT")
    response = requests.post(
        LOCAL_ENDPOINT,
        json=payload,
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    # Fail loudly on HTTP errors instead of returning an error page's body
    # as if it were a prediction.
    response.raise_for_status()
    result = response.json()
    print("RESULT RECEIVED")
    return result