import base64
import io
import logging
import re
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Dict, Generator, List, Tuple, Union

import matplotlib.cm as cm
import numpy as np
import torch
from PIL import Image

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from vidore_benchmark.interpretability.torch_utils import (
    normalize_similarity_map_per_query_token,
)


class SimMapGenerator:
    """
    Generates similarity maps based on query embeddings and image patches using the ColPali model.
    """

    # Note: cm.get_cmap is deprecated in newer matplotlib releases;
    # matplotlib.colormaps["viridis"] is the forward-compatible alternative.
    colormap = cm.get_cmap("viridis")

    def __init__(
        self,
        logger: logging.Logger,
        model_name: str = "vidore/colpali-v1.2",
        n_patch: int = 32,
    ):
        """
        Initializes the SimMapGenerator class with a specified model and patch dimension.

        Args:
            logger (logging.Logger): Logger used for diagnostic output.
            model_name (str): The model name for loading the ColPali model.
            n_patch (int): The number of patches per dimension.
        """
        self.model_name = model_name
        self.n_patch = n_patch
        self.device = get_torch_device("auto")
        self.logger = logger
        self.logger.info(f"Using device: {self.device}")
        self.model, self.processor = self.load_model()

    def load_model(self) -> Tuple[ColPali, ColPaliProcessor]:
        """
        Loads the ColPali model and processor.

        Returns:
            Tuple[ColPali, ColPaliProcessor]: Loaded model and processor.
        """
        model = ColPali.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        ).eval()

        processor = ColPaliProcessor.from_pretrained(self.model_name)
        return model, processor

    def gen_similarity_maps(
        self,
        query: str,
        query_embs: torch.Tensor,
        token_idx_map: Dict[int, str],
        images: List[Union[Path, str]],
        vespa_sim_maps: List[Dict],
    ) -> Generator[Tuple[int, str, int, str], None, None]:
        """
        Generates similarity maps for the provided images and query, and yields
        base64-encoded heatmap images for each (image, token) pair.

        Args:
            query (str): The query string.
            query_embs (torch.Tensor): Query embeddings tensor.
            token_idx_map (Dict[int, str]): Mapping from token indices to tokens.
            images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
            vespa_sim_maps (List[Dict]): List of Vespa similarity maps.

        Yields:
            Tuple[int, str, int, str]: A tuple containing the image index, the selected
            token, the token index, and the base64-encoded heatmap image.
        """
        processed_images, original_images, original_sizes = [], [], []
        for img in images:
            img_pil = self._load_image(img)
            original_images.append(img_pil.copy())
            original_sizes.append(img_pil.size)
            processed_images.append(img_pil)

        # Build a per-token patch-similarity tensor from the Vespa similarity maps,
        # then normalize each query token's map independently.
        vespa_sim_map_tensor = self._prepare_similarity_map_tensor(
            query_embs, vespa_sim_maps
        )
        similarity_map_normalized = normalize_similarity_map_per_query_token(
            vespa_sim_map_tensor
        )

        for idx, img in enumerate(original_images):
            for token_idx, token in token_idx_map.items():
                # Skip special, whitespace-only, and punctuation-only tokens.
                if self.should_filter_token(token):
                    continue

                sim_map = similarity_map_normalized[idx, token_idx, :, :]
                blended_img_base64 = self._blend_image(
                    img, sim_map, original_sizes[idx]
                )
                yield idx, token, token_idx, blended_img_base64

    def _load_image(self, img: Union[Path, str]) -> Image.Image:
        """
        Loads an image from a file path or a base64-encoded string.

        Args:
            img (Union[Path, str]): The image to load.

        Returns:
            Image.Image: The loaded PIL image.
        """
        try:
            if isinstance(img, Path):
                return Image.open(img).convert("RGB")
            if isinstance(img, str):
                return Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to load image: {e}")
        raise ValueError(f"Unsupported image type: {type(img)}")

    def _prepare_similarity_map_tensor(
        self, query_embs: torch.Tensor, vespa_sim_maps: List[Dict]
    ) -> torch.Tensor:
        """
        Prepares a similarity map tensor from Vespa similarity maps.

        Args:
            query_embs (torch.Tensor): Query embeddings tensor.
            vespa_sim_maps (List[Dict]): List of Vespa similarity maps.

        Returns:
            torch.Tensor: The prepared similarity map tensor of shape
            (len(vespa_sim_maps), query_embs.size(1), n_patch, n_patch).
        """
        vespa_sim_map_tensor = torch.zeros(
            (len(vespa_sim_maps), query_embs.size(1), self.n_patch, self.n_patch)
        )
        # Patches at or beyond the image sequence length do not correspond to image
        # patches and are skipped.
        image_seq_length = getattr(self.processor, "image_seq_length", 1024)
        for idx, vespa_sim_map in enumerate(vespa_sim_maps):
            for cell in vespa_sim_map["quantized"]["cells"]:
                patch = int(cell["address"]["patch"])
                if patch >= image_seq_length:
                    continue
                query_token = int(cell["address"]["querytoken"])
                value = cell["value"]
                # Map the flat patch index onto the n_patch x n_patch grid.
                vespa_sim_map_tensor[
                    idx,
                    query_token,
                    patch // self.n_patch,
                    patch % self.n_patch,
                ] = value
        return vespa_sim_map_tensor

    def _blend_image(
        self, img: Image.Image, sim_map: torch.Tensor, original_size: Tuple[int, int]
    ) -> str:
        """
        Renders a similarity map as a colormapped heatmap, sized relative to the
        original image, and encodes it to base64.

        Args:
            img (Image.Image): The original image (not used by the current implementation).
            sim_map (torch.Tensor): The similarity map tensor.
            original_size (Tuple[int, int]): The original size of the image.

        Returns:
            str: The base64-encoded heatmap image (PNG).
        """
        # Downscale the heatmap relative to the original image to keep the payload small.
        SCALING_FACTOR = 8
        sim_map_resolution = (
            max(32, int(original_size[0] / SCALING_FACTOR)),
            max(32, int(original_size[1] / SCALING_FACTOR)),
        )

        # Resize the raw similarity map to the target resolution and normalize to [0, 1].
        sim_map_np = sim_map.cpu().float().numpy()
        sim_map_img = Image.fromarray(sim_map_np).resize(
            sim_map_resolution, resample=Image.BICUBIC
        )
        sim_map_resized_np = np.array(sim_map_img, dtype=np.float32)
        sim_map_normalized = self._normalize_sim_map(sim_map_resized_np)

        # Apply the colormap and encode the resulting RGBA heatmap as a PNG.
        heatmap = self.colormap(sim_map_normalized)
        heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8)).convert("RGBA")

        buffer = io.BytesIO()
        heatmap_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    @staticmethod
    def _normalize_sim_map(sim_map: np.ndarray) -> np.ndarray:
        """
        Normalizes a similarity map to the range [0, 1].

        Args:
            sim_map (np.ndarray): The similarity map.

        Returns:
            np.ndarray: The normalized similarity map.
        """
        sim_map_min, sim_map_max = sim_map.min(), sim_map.max()
        if sim_map_max - sim_map_min > 1e-6:
            return (sim_map - sim_map_min) / (sim_map_max - sim_map_min)
        return np.zeros_like(sim_map)

    @staticmethod
    def should_filter_token(token: str) -> bool:
        """
        Determines whether a token should be filtered out based on predefined patterns.

        The function filters out tokens that:

        - Start with '<' (e.g., '<bos>')
        - Consist entirely of whitespace
        - Are purely punctuation (excluding tokens that contain digits or start with '▁')
        - Start with an underscore '_'
        - Exactly match the word 'Question'
        - Are exactly the single character '▁'

        Output of test:
            Token: '2' | False
            Token: '0' | False
            Token: '2' | False
            Token: '3' | False
            Token: '▁2' | False
            Token: '▁hi' | False
            Token: 'norwegian' | False
            Token: 'unlisted' | False
            Token: '<bos>' | True
            Token: 'Question' | True
            Token: ':' | True
            Token: '<pad>' | True
            Token: '\n' | True
            Token: '▁' | True
            Token: '?' | True
            Token: ')' | True
            Token: '%' | True
            Token: '/)' | True

        Args:
            token (str): The token to check.

        Returns:
            bool: True if the token should be filtered out, False otherwise.
        """
        # Alternatives, in order: starts with '<' | whitespace-only | punctuation-only
        # (no digits, not starting with '▁') | starts with '_' | exactly 'Question' |
        # exactly the SentencePiece word-boundary marker '▁'.
        pattern = re.compile(
            r"^<.*$|^\s+$|^(?!.*\d)(?!▁)[^\w\s]+$|^_.*$|^Question$|^▁$"
        )
        return bool(pattern.match(token))

    @lru_cache(maxsize=128)
    def get_query_embeddings_and_token_map(
        self, query: str
    ) -> Tuple[torch.Tensor, dict]:
        """
        Retrieves query embeddings and a token index map.

        Args:
            query (str): The query string.

        Returns:
            Tuple[torch.Tensor, dict]: Query embeddings and token index map.
        """
        inputs = self.processor.process_queries([query]).to(self.model.device)
        with torch.no_grad():
            q_emb = self.model(**inputs).to("cpu")[0]

        # Recover the token strings for the processed query by re-tokenizing the
        # decoded input ids.
        query_tokens = self.processor.tokenizer.tokenize(
            self.processor.decode(inputs.input_ids[0])
        )
        idx_to_token = {idx: token for idx, token in enumerate(query_tokens)}
        return q_emb, idx_to_token
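

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). It assumes a local image file and a
    # Vespa similarity-map response fetched elsewhere; `example_page.png`, the query
    # string, and the `vespa_sim_maps` content below are hypothetical placeholders.
    logging.basicConfig(level=logging.INFO)
    generator = SimMapGenerator(logger=logging.getLogger(__name__))

    query = "percentage of women in executive positions"
    q_emb, idx_to_token = generator.get_query_embeddings_and_token_map(query)

    images = [Path("example_page.png")]  # hypothetical input image
    vespa_sim_maps = [
        {
            "quantized": {
                "cells": [
                    # Illustrative cell: query token 2 vs. image patch 0.
                    {"address": {"patch": "0", "querytoken": "2"}, "value": 97.0},
                ]
            }
        }
    ]

    for img_idx, token, token_idx, heatmap_b64 in generator.gen_similarity_maps(
        query=query,
        query_embs=q_emb,
        token_idx_map=idx_to_token,
        images=images,
        vespa_sim_maps=vespa_sim_maps,
    ):
        # The heatmap can be embedded directly in HTML, e.g. as a data URI overlay.
        data_uri = f"data:image/png;base64,{heatmap_b64}"
        print(f"image {img_idx}, token {token!r} (idx {token_idx}): {len(data_uri)} chars")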