import torch
from transformers import AutoImageProcessor, AutoModel
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm  # import the callable itself; it is used as tqdm(range(...)) below

device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = AutoImageProcessor.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    token=os.environ.get("HF_TOKEN"),
)
model = AutoModel.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
    token=os.environ.get("HF_TOKEN"),
)
model.eval()  # device_map="auto" already places the weights, so a separate model.to(device) is unnecessary

def get_patch_embeddings(img, ps=16):
    """Return per-patch embeddings for a single image as an array of shape (rows, cols, dim)."""
    inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
    B, C, H, W = inputs["pixel_values"].shape
    rows, cols = H // ps, W // ps

    with torch.no_grad():
        out = model(**inputs)

    hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()

    # The patch tokens sit at the end of the sequence, after the CLS and register tokens,
    # so the last rows*cols rows of the hidden state are the patch embeddings.
    n_patches = rows * cols
    patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)

    return patch_embs
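
# A minimal usage sketch (not from the original): "query.jpg" is a placeholder path.
# For the ViT-S/16 backbone the embedding dimension is 384, and the grid size is
# (H // 16) x (W // 16) of the preprocessed image.
# from PIL import Image
# img = Image.open("query.jpg").convert("RGB")
# patch_embs = get_patch_embeddings(img)
# print(patch_embs.shape)  # (rows, cols, 384)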

def image_query_by_patch(query_patch_emb, patch_embeddings, k):
    """Rank images by their best-matching patch against a single query patch embedding."""
    query_patch_emb = query_patch_emb.half().unsqueeze(0)  # (1, dim)

    scores = []
    for img_embs in patch_embeddings:
        # Dot-product similarity against every patch of the image; score the image
        # by its single best-matching patch
        sim = (img_embs @ query_patch_emb.T).max().item()
        scores.append(sim)

    scores = torch.tensor(scores)
    topk_scores, topk_indices = scores.topk(k)

    return topk_indices, topk_scores
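
# Hedged usage sketch (assumes the embeddings file written by the function below
# already exists; the grid position [7, 7] is an arbitrary example, not from the original):
# store = torch.load("deep_fashion_patch_embeddings.pt")
# patch_embeddings, image_paths = store["patch_embeddings"], store["image_paths"]
# query_embs = get_patch_embeddings(query_img)                    # (rows, cols, dim)
# q = torch.from_numpy(query_embs[7, 7]).to(device)
# q = q / q.norm()                                                # match the L2-normalized stored patches
# indices, scores = image_query_by_patch(q, patch_embeddings, k=5)
# for i, s in zip(indices.tolist(), scores.tolist()):
#     print(image_paths[i], round(s, 3))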

def generate_patch_embeddings_dataset(dataset):
    batch_size = 16
    image_dir = "deepfashion_images"
    os.makedirs(image_dir, exist_ok=True)

    all_patch_embeddings = []
    all_image_paths = []

    with torch.no_grad():
        for i in tqdm(range(0, len(dataset), batch_size)):
            batch = dataset.select(range(i, min(i + batch_size, len(dataset))))

            # Save each image to disk so the search results can be displayed by path later
            imgs = []
            paths = []
            for j, item in enumerate(batch):
                img = item["image"]
                path = os.path.join(image_dir, f"img_{i+j}.jpg")
                img.save(path)
                paths.append(path)
                imgs.append(img)

            for img in imgs:
                patch_embs = get_patch_embeddings(img, ps=16)
                rows, cols, dim = patch_embs.shape

                # Flatten to (n_patches, dim) and L2-normalize so dot products behave as cosine similarity
                patch_embs_flat = patch_embs.reshape(-1, dim)
                patch_embs_flat = patch_embs_flat / (np.linalg.norm(patch_embs_flat, axis=1, keepdims=True) + 1e-8)

                all_patch_embeddings.append(
                    torch.from_numpy(patch_embs_flat).to(device=device, dtype=torch.float16)
                )

            all_image_paths.extend(paths)

    torch.save({
        "patch_embeddings": all_patch_embeddings,
        "image_paths": all_image_paths
    }, "deep_fashion_patch_embeddings.pt")

    print(f"Saved patch-level embeddings for {len(all_image_paths)} images!")
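
# Hedged end-to-end sketch (the dataset identifier is a placeholder, not from the original;
# any Hugging Face dataset with a PIL "image" column should work):
# from datasets import load_dataset
# dataset = load_dataset("your-namespace/deepfashion-subset", split="train")  # hypothetical repo id
# generate_patch_embeddings_dataset(dataset)
# store = torch.load("deep_fashion_patch_embeddings.pt")
# print(len(store["patch_embeddings"]), len(store["image_paths"]))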