|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import gradio as gr |
|
|
from utils import get_patch_embeddings, image_query_by_patch |
|
|
import torch |
|
|
|
|
|
# Run tensor math on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
import os |
|
|
import zipfile |
|
|
|
|
|
# Archive holding the gallery images and the directory it is unpacked into.
zip_path = "deepfashion_images.zip"
extract_dir = "deepfashion_images"

# Unpack the archive once; later runs reuse the already-extracted directory.
# NOTE: only the directory's existence is checked, so a partially extracted
# directory from an interrupted run is treated as complete.
if not os.path.exists(extract_dir):
    print("Extracting images...")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(extract_dir)
    print("✅ Extraction complete!")
else:
    print("✅ Images already extracted.")
|
|
|
|
|
# Load the precomputed gallery embeddings onto the selected device.
# SECURITY NOTE: torch.load unpickles the file, so this must only ever be
# pointed at a trusted checkpoint. If the payload is limited to tensors and
# plain containers, consider passing weights_only=True (TODO confirm the
# stored types before switching).
data = torch.load("deep_fashion_patch_embeddings.pt", map_location=device)

# Indexed per gallery image below: patch_embeddings[i] is unpacked as a
# (num_patches, dim) tensor and image_paths[i] is appended to the image
# directory prefix.
patch_embeddings = data["patch_embeddings"]
image_paths = data["image_paths"]
|
|
|
|
|
def _render_panel(img, title, sims=None, alpha=0.0):
    """Draw ``img`` (optionally with a similarity heatmap) and return it as a PIL image."""
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(img)
        if sims is not None:
            # Stretch the coarse patch grid over the full image area.
            ax.imshow(sims, cmap='jet', alpha=alpha,
                      extent=(0, img.width, img.height, 0))
        ax.set_title(title)
        ax.axis('off')
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps a global reference to every open figure; close it even
        # when rendering fails so repeated requests do not accumulate figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _similarity_map(flat_patch_embs, query_emb, rows, cols):
    """Return the dot product of every patch embedding with ``query_emb`` as a (rows, cols) array."""
    sims = flat_patch_embs @ query_emb.view(-1, 1)
    return sims.squeeze().cpu().numpy().reshape(rows, cols)


def plot_topk_images(query_img, alpha, k=5, show_overlay=True):
    """Find the gallery images most similar to the centre patch of ``query_img``.

    The query's patch embeddings are computed with ``get_patch_embeddings``,
    the embedding of the centre patch is used as the search key for
    ``image_query_by_patch``, and the query plus the returned gallery images
    are collected for display.

    Args:
        query_img: PIL image to search with, or ``None`` when nothing has
            been uploaded.
        alpha: Opacity of the heatmap overlay passed to Matplotlib.
        k: Number of gallery images to retrieve.
        show_overlay: When true, every panel is drawn with a patch-similarity
            heatmap; when false, gallery results are returned as plain images.

    Returns:
        A list of PIL images: the rendered query panel first, followed by one
        entry per retrieved image. Empty when ``query_img`` is ``None``.
    """
    # Nothing uploaded yet: return an empty gallery instead of crashing.
    if query_img is None:
        return []

    # Defensive: numeric UI widgets may hand over a float for a whole number.
    k = int(k)

    patch_embs_query = get_patch_embeddings(query_img, ps=16)
    rows, cols, dim = patch_embs_query.shape

    # The patch in the middle of the grid is the search key.
    centre_patch = patch_embs_query[rows // 2, cols // 2]
    query_patch_emb = torch.from_numpy(centre_patch).float().to(device=device, dtype=torch.float16)

    topk_indices, topk_scores = image_query_by_patch(query_patch_emb, patch_embeddings, k)

    # First panel: the query itself, with its own similarity map if requested.
    sims_query = None
    if show_overlay:
        patch_embs_flat_query = torch.from_numpy(
            patch_embs_query.reshape(-1, dim)
        ).float().to(device=device, dtype=torch.float16)
        sims_query = _similarity_map(patch_embs_flat_query, query_patch_emb, rows, cols)
    query_title = "Query + Heatmap" if show_overlay else "Query"
    all_imgs = [_render_panel(query_img, query_title, sims_query, alpha)]

    for idx, score in zip(topk_indices, topk_scores):
        sim_img_path = 'deepfashion_images/' + image_paths[idx]
        sim_img = Image.open(sim_img_path).convert("RGB")

        if not show_overlay:
            all_imgs.append(sim_img)
            continue

        patch_embs_sim = patch_embeddings[idx]
        num_patches, sim_dim = patch_embs_sim.shape
        # The stored embeddings are flat, so the patch grid is reconstructed
        # as a near-square layout.
        # NOTE(review): the reshape below raises unless
        # rows_sim * cols_sim == num_patches — presumably the gallery was
        # embedded on a square grid; confirm against the indexing script.
        cols_sim = int(np.sqrt(num_patches))
        rows_sim = (num_patches + cols_sim - 1) // cols_sim
        patch_embs_flat_sim = patch_embs_sim.reshape(-1, sim_dim).float().to(
            device=device, dtype=torch.float16
        )
        sims = _similarity_map(patch_embs_flat_sim, query_patch_emb, rows_sim, cols_sim)
        all_imgs.append(_render_panel(sim_img, f"{score:.2f}", sims, alpha))

    return all_imgs
|
|
|
|
|
# Dropdown label -> image file offered as a ready-made query.
# The paths are relative, so they resolve against the working directory.
SAMPLE_IMAGES = {
    "Blue T-Shirt": "blue-tshirt.jpg",
    "Levi's White T-Shirt": "levis tee.jpg",
    "Striped Shirt": "striped-shirt.jpg",
}
|
|
|
|
|
def update_input_img(sample_choice):
    """Load the sample image selected in the dropdown.

    Args:
        sample_choice: Key into ``SAMPLE_IMAGES``.

    Returns:
        The chosen sample as an RGB PIL image, or ``None`` when
        ``sample_choice`` is not one of the known sample names.
    """
    sample_path = SAMPLE_IMAGES.get(sample_choice)
    if sample_path is None:
        return None
    # convert() produces a fully loaded copy, so the file can be closed here.
    with Image.open(sample_path) as img:
        return img.convert("RGB")
|
|
|
|
|
def get_query_image(uploaded_img):
    """Return ``uploaded_img`` unchanged.

    This is a pass-through; no event handler visible in this file is wired
    to it.
    """
    return uploaded_img
|
|
|
|
|
# Build the Gradio interface: an image input with a sample picker and search
# controls on top, the result gallery and the search button below.
with gr.Blocks() as demo:
    gr.Markdown("""
    <h1 style="font-size:36px; font-weight:bold;">Visual Fashion Search using DINOv3 Embeddings</h1>
    <p style="font-size:18px;"> This tool allows you to find similar fashion items using image embeddings. To use: <br>
    <ul style="font-size:18px;">
    <li>Upload an image, choose a sample image or use your webcam!</li>
    <li>Select the value of k for top-k results and alpha for the overlay.</li>
    </ul>
    """)

    with gr.Row():
        input_img = gr.Image(label="Upload Image", type="pil")
        sample_choice = gr.Dropdown(choices=list(SAMPLE_IMAGES.keys()), label="Or select a sample image")
        with gr.Column():
            k_input = gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of Results (k)")
            alpha_input = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4, label="Overlay Transparency (alpha)")
            overlay_input = gr.Checkbox(label="Show Overlay", value=False)

    gallery_output = gr.Gallery(label="Results", columns=5)
    btn = gr.Button("Search")

    # Picking a sample replaces whatever is currently in the image input.
    sample_choice.change(fn=update_input_img, inputs=[sample_choice], outputs=[input_img])

    # The input order must match plot_topk_images(query_img, alpha, k, show_overlay).
    btn.click(fn=plot_topk_images,
              inputs=[input_img, alpha_input, k_input, overlay_input],
              outputs=gallery_output)

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
|
|
|