import torch
from transformers import AutoImageProcessor, AutoModel
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = AutoImageProcessor.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    token=os.environ.get("HF_TOKEN"),
)
model = AutoModel.from_pretrained(
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    dtype=torch.float16,
    attn_implementation="sdpa",
    token=os.environ.get("HF_TOKEN"),
)
model.eval()
model.to(device)

def get_patch_embeddings(img, ps=16):
    """Return the (rows, cols, dim) grid of patch embeddings for one image."""
    inputs = processor(images=img, return_tensors="pt").to(device, torch.float16)
    B, C, H, W = inputs["pixel_values"].shape
    rows, cols = H // ps, W // ps

    with torch.no_grad():
        out = model(**inputs)

    hs = out.last_hidden_state.squeeze(0).cpu().numpy()

    # Drop the CLS and register tokens: the patch tokens are the last rows*cols entries.
    n_patches = rows * cols
    patch_embs = hs[-n_patches:, :].reshape(rows, cols, -1)

    return patch_embs
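
# Example sketch (not part of the original pipeline): visualize how one query
# patch matches every patch of the same image as a cosine-similarity heatmap.
# `row`/`col` index the query patch in the (rows, cols) grid; the helper name
# and signature are assumptions for illustration.
def plot_patch_similarity(img, row, col, ps=16):
    patch_embs = get_patch_embeddings(img, ps=ps)  # (rows, cols, dim)
    rows, cols, dim = patch_embs.shape

    # Normalize so dot products are cosine similarities.
    flat = patch_embs.reshape(-1, dim)
    flat = flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-8)

    query = flat[row * cols + col]
    heatmap = (flat @ query).reshape(rows, cols)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(img)
    axes[0].set_title("image")
    axes[0].axis("off")
    axes[1].imshow(heatmap, cmap="viridis")
    axes[1].set_title(f"similarity to patch ({row}, {col})")
    plt.show()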

def image_query_by_patch(query_patch_emb, patch_embeddings, k):
    """Rank images by their best-matching patch and return the top-k.

    Both the query and the stored patch embeddings are assumed to be
    L2-normalized, so the dot product below is cosine similarity.
    """
    query_patch_emb = query_patch_emb.to(device=device, dtype=torch.float16).unsqueeze(0)

    # Score each image by the similarity of its single best patch.
    scores = []
    for img_embs in patch_embeddings:
        sim = (img_embs @ query_patch_emb.T).max().item()
        scores.append(sim)

    topk_scores, topk_indices = torch.tensor(scores).topk(k)

    return topk_indices, topk_scores

def generate_patch_embeddings_dataset(dataset):
    """Embed every image in the dataset and save the patch embeddings to disk."""
    batch_size = 16
    image_dir = "deepfashion_images"
    os.makedirs(image_dir, exist_ok=True)

    all_patch_embeddings = []
    all_image_paths = []

    with torch.no_grad():
        for i in tqdm(range(0, len(dataset), batch_size)):
            batch = dataset.select(range(i, min(i + batch_size, len(dataset))))

            # Save each image to disk so the top-k results can be displayed later.
            imgs = []
            paths = []
            for j, item in enumerate(batch):
                img = item["image"]
                path = os.path.join(image_dir, f"img_{i+j}.jpg")
                img.save(path)
                paths.append(path)
                imgs.append(img)

            for img in imgs:
                patch_embs = get_patch_embeddings(img, ps=16)
                rows, cols, dim = patch_embs.shape

                # L2-normalize so dot products act as cosine similarity at query time.
                patch_embs_flat = patch_embs.reshape(-1, dim)
                patch_embs_flat = patch_embs_flat / (np.linalg.norm(patch_embs_flat, axis=1, keepdims=True) + 1e-8)

                all_patch_embeddings.append(
                    torch.from_numpy(patch_embs_flat).to(device=device, dtype=torch.float16)
                )

            all_image_paths.extend(paths)

    torch.save({
        "patch_embeddings": all_patch_embeddings,
        "image_paths": all_image_paths
    }, "deep_fashion_patch_embeddings.pt")

    print(f"Saved patch-level embeddings for {len(all_image_paths)} images!")