import torch
from transformers import AutoImageProcessor, AutoModel
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = AutoImageProcessor.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    token=os.environ.get("HF_TOKEN"),
)
model = AutoModel.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
    token=os.environ.get("HF_TOKEN"),
)
model.eval()
# device_map="auto" already places the weights, so an explicit model.to(device)
# is unnecessary (and can error for accelerate-dispatched models)
def get_patch_embeddings(img, ps=16, device=device):
    """Return the per-patch embedding grid of shape (rows, cols, dim) for one image."""
    inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
    B, C, H, W = inputs["pixel_values"].shape
    rows, cols = H // ps, W // ps
    with torch.no_grad():
        out = model(**inputs)
    hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()
    # Drop the CLS + register tokens: the last rows*cols tokens are the patch tokens
    n_patches = rows * cols
    patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)
    # L2 normalization happens downstream, so the raw patch grid is returned here
    return patch_embs
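# Minimal sketch (not part of the original upload) of calling get_patch_embeddings
# on a single image; the file name below is a placeholder.
# from PIL import Image
# query_img = Image.open("query.jpg").convert("RGB")
# grid = get_patch_embeddings(query_img)  # numpy array of shape (rows, cols, dim)
# print(grid.shape)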
def image_query_by_patch(query_patch_emb, patch_embeddings, k):
    """Score each image by its best-matching patch and return the top-k images."""
    query_patch_emb = query_patch_emb.half().unsqueeze(0)  # (1, dim)
    scores = []
    for img_embs in patch_embeddings:
        # Similarity of the query patch against every patch of the image; keep the best match
        sim = (img_embs @ query_patch_emb.T).max().item()
        scores.append(sim)
    scores = torch.tensor(scores)
    topk_scores, topk_indices = scores.topk(k)
    return topk_indices, topk_scores
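# Hedged usage sketch for the patch-level query: the embeddings file comes from
# generate_patch_embeddings_dataset below, while the query image, the (row, col)
# patch index, and k are placeholder values, not part of the original script.
# saved = torch.load("deep_fashion_patch_embeddings.pt")
# q_grid = get_patch_embeddings(Image.open("query.jpg").convert("RGB"))
# q_patch = torch.from_numpy(q_grid[5, 7])
# q_patch = q_patch / (q_patch.norm() + 1e-8)  # match the stored L2-normalized patches
# idx, scores = image_query_by_patch(q_patch.to(device), saved["patch_embeddings"], k=5)
# best_paths = [saved["image_paths"][i] for i in idx.tolist()]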
def generate_patch_embeddings_dataset(dataset):
    """Embed every image in the dataset at patch level and save the results to disk."""
    batch_size = 16
    image_dir = "deepfashion_images"
    os.makedirs(image_dir, exist_ok=True)
    all_patch_embeddings = []
    all_image_paths = []
    with torch.no_grad():
        for i in tqdm(range(0, len(dataset), batch_size)):
            batch = dataset.select(range(i, min(i + batch_size, len(dataset))))
            imgs = []
            paths = []
            for j, item in enumerate(batch):
                # Convert to RGB so JPEG saving works for palette/alpha images
                img = item["image"].convert("RGB")
                path = os.path.join(image_dir, f"img_{i+j}.jpg")
                img.save(path)
                paths.append(path)
                imgs.append(img)
            for img in imgs:
                patch_embs = get_patch_embeddings(img, ps=16, device=device)
                rows, cols, dim = patch_embs.shape
                # Flatten the patch grid and L2-normalize so dot products are cosine similarities
                patch_embs_flat = patch_embs.reshape(-1, dim)
                patch_embs_flat = patch_embs_flat / (np.linalg.norm(patch_embs_flat, axis=1, keepdims=True) + 1e-8)
                all_patch_embeddings.append(
                    torch.from_numpy(patch_embs_flat).to(device=device, dtype=torch.float16)
                )
            all_image_paths.extend(paths)

    torch.save({
        "patch_embeddings": all_patch_embeddings,
        "image_paths": all_image_paths
    }, "deep_fashion_patch_embeddings.pt")
    print(f"Saved patch-level embeddings for {len(all_image_paths)} images!")