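"""Gradio demo: visual fashion search over DeepFashion images using DINOv3 patch embeddings."""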
import os
import zipfile
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from utils import get_patch_embeddings, image_query_by_patch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
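# Contracts for the helpers in utils.py (not shown here; shapes inferred from usage below):
#   get_patch_embeddings(img, ps=16) -> np.ndarray of shape (rows, cols, dim),
#       one DINOv3 embedding per ps x ps image patch.
#   image_query_by_patch(query_emb, patch_embeddings, k) -> (topk_indices, topk_scores)
#       for the k gallery images that best match the query patch embedding.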
# One-time setup: unpack the DeepFashion gallery images next to the app.
zip_path = "deepfashion_images.zip"
extract_dir = "deepfashion_images"
if not os.path.exists(extract_dir):
    print("Extracting images...")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(extract_dir)
    print("✅ Extraction complete!")
else:
    print("✅ Images already extracted.")
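# Load precomputed gallery embeddings. The .pt file is expected to hold a dict
# with "patch_embeddings" (per-image patch embedding matrices) and "image_paths"
# (paths relative to the extracted image folder).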
data = torch.load("deep_fashion_patch_embeddings.pt", map_location=device)
patch_embeddings = data["patch_embeddings"]
image_paths = data["image_paths"]
def plot_topk_images(query_img, alpha, k=5, show_overlay=True):
    # Embed the query image: one DINOv3 embedding per 16x16 patch.
    patch_embs_query = get_patch_embeddings(query_img, ps=16)
    rows, cols, dim = patch_embs_query.shape

    # Use the central patch as the query vector.
    query_patch_emb = patch_embs_query[rows // 2, cols // 2]

    # Alternative: average a small grid of central patches (e.g., 4x4)
    # for a smoother query vector.
    # grid_size = 4
    # start_row = max(0, rows // 2 - grid_size // 2)
    # start_col = max(0, cols // 2 - grid_size // 2)
    # central_grid = patch_embs_query[start_row:start_row + grid_size,
    #                                 start_col:start_col + grid_size, :]
    # query_patch_emb = central_grid.reshape(-1, central_grid.shape[-1]).mean(axis=0)

    query_patch_emb = torch.from_numpy(query_patch_emb).to(device=device, dtype=torch.float16)
    topk_indices, topk_scores = image_query_by_patch(query_patch_emb, patch_embeddings, k)

    all_imgs = []

    # Render the query image, optionally with a patch-similarity heatmap.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(query_img)
    if show_overlay:
        patch_embs_flat_query = torch.from_numpy(patch_embs_query.reshape(-1, dim)).to(device=device, dtype=torch.float16)
        # Dot products against the query patch (cosine similarity if the
        # embeddings are L2-normalized), reshaped back to the patch grid.
        sims_query = (patch_embs_flat_query @ query_patch_emb.view(-1, 1)).squeeze().cpu().numpy().reshape(rows, cols)
        ax.imshow(sims_query, cmap='jet', alpha=alpha, extent=(0, query_img.width, query_img.height, 0))
    ax.set_title("Query + Heatmap")
    ax.axis('off')
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    all_imgs.append(Image.open(buf))
    plt.close(fig)
    # Fetch and render the top-k matches.
    for idx, score in zip(topk_indices, topk_scores):
        sim_img_path = os.path.join(extract_dir, image_paths[idx])
        sim_img = Image.open(sim_img_path).convert("RGB")
        if show_overlay:
            patch_embs_sim = patch_embeddings[idx]
            num_patches, dim = patch_embs_sim.shape
            # Recover the patch grid; assumes a (near-)square grid, as produced
            # by get_patch_embeddings on the gallery images.
            cols_sim = int(np.sqrt(num_patches))
            rows_sim = (num_patches + cols_sim - 1) // cols_sim
            patch_embs_flat_sim = patch_embs_sim.reshape(-1, dim).to(device=device, dtype=torch.float16)
            sims = (patch_embs_flat_sim @ query_patch_emb.view(-1, 1)).squeeze().cpu().numpy().reshape(rows_sim, cols_sim)
            fig, ax = plt.subplots(figsize=(4, 4))
            ax.imshow(sim_img)
            ax.imshow(sims, cmap='jet', alpha=alpha, extent=(0, sim_img.width, sim_img.height, 0))
            ax.set_title(f"{score:.2f}")
            ax.axis('off')
            buf = BytesIO()
            fig.savefig(buf, format="png")
            buf.seek(0)
            all_imgs.append(Image.open(buf))
            plt.close(fig)
        else:
            all_imgs.append(sim_img)
    return all_imgs
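# Example (outside the UI), assuming the sample image below exists locally:
#   results = plot_topk_images(Image.open("blue-tshirt.jpg").convert("RGB"), alpha=0.4, k=5)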
SAMPLE_IMAGES = {
    "Blue T-Shirt": "blue-tshirt.jpg",
    "Levi's White T-Shirt": "levis tee.jpg",
    "Striped Shirt": "striped-shirt.jpg",
}

def update_input_img(sample_choice):
    if sample_choice in SAMPLE_IMAGES:
        return Image.open(SAMPLE_IMAGES[sample_choice]).convert("RGB")
    return None

def get_query_image(uploaded_img):
    return uploaded_img
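# Gradio UI: upload/sample inputs, k and alpha controls, and a results gallery.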
with gr.Blocks() as demo:
    gr.Markdown("""
    <h1 style="font-size:36px; font-weight:bold;">Visual Fashion Search using DINOv3 Embeddings</h1>
    <p style="font-size:18px;">This tool finds similar fashion items using image embeddings. To use:</p>
    <ul style="font-size:18px;">
      <li>Upload an image, choose a sample image, or use your webcam.</li>
      <li>Select k for the number of results and alpha for the overlay transparency.</li>
    </ul>
    """)
    with gr.Row():
        input_img = gr.Image(label="Upload Image", type="pil")
        sample_choice = gr.Dropdown(choices=list(SAMPLE_IMAGES.keys()), label="Or select a sample image")
        with gr.Column():
            k_input = gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of Results (k)")
            alpha_input = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4, label="Overlay Transparency (alpha)")
            overlay_input = gr.Checkbox(label="Show Overlay", value=False)
    gallery_output = gr.Gallery(label="Results", columns=5)
    btn = gr.Button("Search")

    sample_choice.change(fn=update_input_img, inputs=[sample_choice], outputs=[input_img])
    btn.click(fn=plot_topk_images,
              inputs=[input_img, alpha_input, k_input, overlay_input],
              outputs=gallery_output)

demo.launch(share=True)