- Added python_server folder
BIN python_server/__pycache__/server.cpython-313.pyc Normal file
Binary file not shown.
21 python_server/downloadModel.py Normal file
@@ -0,0 +1,21 @@
from transformers import CLIPModel, CLIPProcessor

# pip install torch torchvision torchaudio
# pip install transformers

# Local folder where you want to save the model
#local_model_path = "models/clip-vit-base-patch32"
local_model_path = "models/clip-vit-large-patch14"

# Load model & processor (the checkpoint must match the local folder name above,
# and server.py loads from that folder)
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Save them locally
model.save_pretrained(local_model_path)
processor.save_pretrained(local_model_path)

print(f"Model saved to {local_model_path}")

# Then, package the model folder from the command line:
# tar -czvf clip_model.tar.gz models
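Before shipping the tarball, it's worth confirming the export is loadable offline. A minimal sketch, assuming the folder above; local_files_only=True forces transformers to read from disk only, and 768 is the projection dim clip-vit-large-patch14 should report:

from transformers import CLIPModel, CLIPProcessor

# Reload strictly from the local folder (no Hub access) to verify the export
model = CLIPModel.from_pretrained("models/clip-vit-large-patch14", local_files_only=True)
processor = CLIPProcessor.from_pretrained("models/clip-vit-large-patch14", local_files_only=True)
print(model.config.projection_dim)  # 768 for clip-vit-large-patch14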
164 python_server/server.py Normal file
@@ -0,0 +1,164 @@
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import io
import json
from qdrant_client import QdrantClient
import base64
import logging
from fastapi import FastAPI, Request
import uvicorn

#from pyngrok import ngrok, conf

#conf.get_default().auth_token = "31y03xgLo8cpZ0WujlWkTCT34qK_86cXFLfdH2c3gX3SEezaz"

# ngrok config add-authtoken 31y03xgLo8cpZ0WujlWkTCT34qK_86cXFLfdH2c3gX3SEezaz
# ngrok http --url=pegasus-working-bison.ngrok-free.app 8000

# Command to run with Splunk Observability:
# opentelemetry-instrument python server.py

from fastapi.middleware.cors import CORSMiddleware

# ==============================
# Logging
# ==============================
logging.basicConfig(level=logging.INFO)

# ==============================
# Model & DB initialization
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Adjust this path for local model folder
#model_path = "./model/clip-vit-base-patch32"
model_path = "models/clip-vit-large-patch14"

logging.info("Loading CLIP model...")
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)
model.eval()
logging.info("Model loaded.")

# Qdrant connection
'''qdrant = QdrantClient(
    url="https://q-vector-db.beprime.mx",
    port=443,
    api_key="GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z",
    https=True
)'''
qdrant = QdrantClient(
    host="172.21.4.201",
    port=5050,
    api_key="GJY54XaG0B94DlKAt5IH9T1Ez67u4R7Z",
    https=False
)
COLLECTION_NAME = "nazan-ssa"
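Nothing in this commit creates the collection itself. A sketch of the one-time setup it presumably pairs with — cosine distance and 768-dim vectors (the output size of clip-vit-large-patch14) are assumptions, not confirmed by this diff:

from qdrant_client.models import Distance, VectorParams

# Hypothetical one-time setup; size=768 assumes clip-vit-large-patch14 embeddings
qdrant.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)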

# ==============================
# Helper functions
# ==============================
def get_embeddings(image_bytes: bytes):
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            embedding = model.get_image_features(**inputs)
        # L2-normalize so downstream cosine similarity reduces to a dot product
        embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
        return embedding.cpu().numpy().astype("float32").flatten()
    except Exception as e:
        logging.error(f"Error in get_embeddings: {e}")
        raise
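A quick standalone check of the helper — test.jpg is a placeholder for any local image:

# Embed one local image and confirm shape and unit norm
with open("test.jpg", "rb") as f:
    vec = get_embeddings(f.read())
print(vec.shape, float((vec ** 2).sum()))  # (768,) and ~1.0 after L2 normalization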

def similarity_search(embedding, top_k=1, threshold=0.7):
    try:
        embedding_list = embedding.astype(float).tolist()
        results = qdrant.search(
            collection_name=COLLECTION_NAME,
            query_vector=embedding_list,
            limit=top_k
        )
        points = getattr(results, "points", results)  # works whether the client returns a list or a response object
        if not points:
            return {"status": False, "SKU": None}
        best = points[0]
        if best.score >= threshold:
            print({"status": True, "SKU": best.payload.get("sku")}, "threshold:", best.score)
            return {"status": True, "SKU": best.payload.get("sku")}
        print({"status": False, "SKU": None}, "threshold:", best.score)
        return {"status": False, "SKU": None}
    except Exception as e:
        logging.error(f"Error in similarity_search: {e}")
        raise  # re-raise instead of silently returning None to the endpoint
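Since get_embeddings returns unit-length vectors, a cosine-distance collection makes best.score a plain dot product, so threshold=0.7 reads directly as cosine similarity (cosine distance is an assumption here; the collection config isn't in this diff). A tiny numeric check:

import numpy as np

a = np.array([1.0, 0.0])  # unit vector
b = np.array([0.8, 0.6])  # also unit length: 0.8**2 + 0.6**2 == 1.0
print(float(a @ b))       # 0.8 -> would clear the 0.7 threshold above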

# ==============================
# FastAPI app
# ==============================
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or list of domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/predict")
async def predict(request: Request):
    try:
        data = await request.json()
        if "image_bytes" not in data:
            return {"status": False, "error": "Missing 'image_bytes'"}
        image_bytes = base64.b64decode(data["image_bytes"])
        embedding = get_embeddings(image_bytes)
        result = similarity_search(embedding)
        return result
    except Exception as e:
        logging.error(f"Error in /predict: {e}")
        return {"status": False, "error": str(e)}

from fastapi import File, UploadFile
from fastapi.responses import JSONResponse

@app.post("/predictfile")
async def predict_file(file: UploadFile = File(...)):
    try:
        # Read uploaded file as bytes
        image_bytes = await file.read()

        # Generate embedding
        embedding = get_embeddings(image_bytes)

        # Run similarity search
        result = similarity_search(embedding)

        return result
    except Exception as e:
        logging.error(f"Error in /predictfile: {e}")
        return JSONResponse(content={"status": False, "error": str(e)})
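For reference, both routes can be exercised with a short client — a sketch only; the host/port and test.jpg are placeholders:

import base64
import requests

# /predict expects base64-encoded bytes under "image_bytes"
with open("test.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")
print(requests.post("http://127.0.0.1:8000/predict", json={"image_bytes": b64}).json())

# /predictfile takes a raw multipart upload instead
with open("test.jpg", "rb") as f:
    print(requests.post("http://127.0.0.1:8000/predictfile", files={"file": f}).json())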

#from pyngrok import ngrok
import multiprocessing

def run_server():
    uvicorn.run("server:app", host="0.0.0.0", port=8000)

if __name__ == "__main__":
    # Open ngrok tunnel
    #public_url = ngrok.connect(8000).public_url
    #print("🌍 Public URL:", public_url)

    # Run FastAPI server in a subprocess
    #p = multiprocessing.Process(target=run_server)
    #p.start()
    #p.join()
    uvicorn.run("server:app", host="0.0.0.0", port=8000)
20 python_server/windmill.py Normal file
@@ -0,0 +1,20 @@
import wmill
from wmill import S3Object
import base64
import json
import requests

LOCAL_ENDPOINT = "http://127.0.0.1:8000/predict"

def main(input_file: S3Object):
    # Load image from S3
    s3_bytes = wmill.load_s3_file(input_file)
    payload = {"image_bytes": base64.b64encode(s3_bytes).decode("utf-8")}

    print("CALLING LOCAL ENDPOINT")

    response = requests.post(LOCAL_ENDPOINT, json=payload)
    result = response.json()

    print("RESULT RECEIVED")
    return result