"""Video processor class for Molmo2""" |
|
|
from typing import TYPE_CHECKING, Tuple, List, Optional, Union, Dict, Any |
|
|
import numpy as np |
|
|
import einops |
|
|
import torch |
|
|
import torchvision.transforms |
|
|
from torchvision.transforms import InterpolationMode |
|
|
from torchvision.transforms.functional import convert_image_dtype |
|
|
|
|
|
from transformers.image_utils import ( |
|
|
IMAGENET_STANDARD_MEAN, |
|
|
IMAGENET_STANDARD_STD, |
|
|
ImageInput, |
|
|
PILImageResampling, |
|
|
SizeDict, |
|
|
) |
|
|
from transformers.video_utils import ( |
|
|
VideoInput, |
|
|
valid_videos, |
|
|
make_batched_videos, |
|
|
) |
|
|
from transformers.processing_utils import Unpack, VideosKwargs |
|
|
from transformers.video_processing_utils import BaseVideoProcessor |
|
|
from transformers.utils import logging |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
from transformers.utils import TensorType, logging, to_numpy |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from transformers.utils import TensorType, logging |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|


def normalize_image(
    image: np.ndarray,
    image_mean: List[float],
    image_std: List[float],
) -> np.ndarray:
    """Normalize a channel-last float image in place with the given per-channel mean and std."""
    image -= np.array(image_mean, dtype=np.float32)[None, None, :]
    image /= np.array(image_std, dtype=np.float32)[None, None, :]
    return image


def resize_image(
    image: np.ndarray,
    desired_output_size: List[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize an image [h, w, c] or video [t, h, w, c] and rescale it to float32 values in [0, 1]."""
    if len(image.shape) == 3:
        is_video = False
        image = torch.permute(torch.from_numpy(image), [2, 0, 1])
    else:
        is_video = True
        image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
    dtype = image.dtype
    if torch.is_floating_point(image):
        in_min = 0.0
        in_max = 1.0
        resized = torchvision.transforms.Resize(
            desired_output_size,
            resample,
            antialias=False,
        )(image)
        resized = torch.clip(resized, 0.0, 1.0).to(dtype)
    else:
        assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype)
        in_min = 0.0
        in_max = 255.0
        resized = torchvision.transforms.Resize(
            desired_output_size,
            resample,
            antialias=False,
        )(image)
        resized = torch.clip(resized, 0, 255).to(dtype)

    # Rescale to [0, 1] regardless of the input dtype.
    resized = resized.to(torch.float32)
    resized = (resized - in_min) / (in_max - in_min)

    if is_video:
        resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
    else:
        resized = torch.permute(resized, [1, 2, 0]).numpy()

    return resized


def build_resized_image(
    image: np.ndarray,
    base_image_input_size: List[int],
    resample: PILImageResampling,
    image_mean: List[float],
    image_std: List[float],
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Resize and normalize an image, and build the [h_patches, w_patches] grid of its patch indices."""
    resized = resize_image(
        image, base_image_input_size, resample,
    )
    resized = normalize_image(resized, image_mean, image_std)
    if len(resized.shape) == 3:
        resized = np.expand_dims(resized, 0)
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size
    resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
    return resized, resize_idx


def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
    if len(array.shape) == 3:
        n_crops, h, w = array.shape
        h_patches = h // patch_size
        w_patches = w // patch_size
        array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
        array = np.transpose(array, [0, 1, 3, 2, 4])
        array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
        return array
    else:
        n_crops, h, w, c = array.shape
        h_patches = h // patch_size
        w_patches = w // patch_size
        array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
        array = np.transpose(array, [0, 1, 3, 2, 4, 5])
        array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
        return array
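

# For illustration (assuming the default 378x378 input and 14-pixel patches):
# batch_pixels_to_patches(np.zeros([1, 378, 378, 3]), 14) has shape [1, 729, 588],
# i.e. 27x27 = 729 patches per crop and 14 * 14 * 3 = 588 pixel values per patch.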


def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Group a [h, w] grid of patch indices into pooling windows, padding with -1 where needed."""
    h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
    w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
    idx_arr = np.pad(idx_arr, [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
                     mode='constant', constant_values=-1)
    return einops.rearrange(
        idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
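

# For illustration: pooling a 27x27 index grid with a 3x3 window yields a [9, 9, 9] array whose
# entries list the nine patch indices belonging to each pooled token; if the grid is not evenly
# divisible by the window, it is first padded (roughly symmetrically) with -1 entries.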


def image_to_patches_and_grids(
    image: ImageInput,
    base_image_input_size: List[int],
    resample: PILImageResampling,
    image_mean: List[float],
    image_std: List[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    :return image_grid, the shape of the image after pooling
    :return crops, the image crops to process with the ViT
    :return pooled_patch_idx, for each pooled image token, the indices of the patches in `crops`
        to pool for that token, padded with -1
    """
    if isinstance(base_image_input_size, int):
        base_image_input_size = (base_image_input_size, base_image_input_size)

    pooling_w = image_pooling_w
    pooling_h = image_pooling_h

    resized, resize_idx = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )
    pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
    h, w = pooling_idx.shape[:2]
    pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
    image_grid = [h, w]
    return (
        image_grid,
        batch_pixels_to_patches(resized, image_patch_size),
        pooling_idx,
    )
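

# For illustration: with the default 378x378 input size, 14-pixel patches and 3x3 pooling, a single
# frame produces crops of shape [1, 729, 588], an image_grid of [9, 9], and a pooled_patch_idx of
# shape [81, 9] mapping each pooled token to the nine patches it covers.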


class Molmo2VideoProcessorKwargs(VideosKwargs, total=False):
    patch_size: Optional[int]
    pooling_size: Optional[List[int]]


class Molmo2VideoProcessor(BaseVideoProcessor):
    """Constructs a Molmo2 video processor."""

    resample = PILImageResampling.BILINEAR
    size = {"height": 378, "width": 378}
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    patch_size = 14
    pooling_size = [3, 3]
    valid_kwargs = Molmo2VideoProcessorKwargs
    model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]

    def __init__(self, **kwargs: Unpack[Molmo2VideoProcessorKwargs]):
        super().__init__(**kwargs)
        if self.size is not None and (
            self.size.get("height", None) is None or self.size.get("width", None) is None
        ):
            raise ValueError("size must contain 'height' and 'width' keys.")

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated.
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        if size is not None and ("height" not in size or "width" not in size):
            raise ValueError("size must contain 'height' and 'width' keys.")

        return super()._further_process_kwargs(size=size, **kwargs)

    def preprocess(
        self,
        videos: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
        size: Optional[dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        patch_size: Optional[int] = None,
        pooling_size: Optional[List[int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a video for the model.

        Args:
            videos (`VideoInput`):
                Video to preprocess.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Size of each frame after resizing.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use when resizing frames. This can be one of the enum
                `PILImageResampling`. Only has an effect if `do_resize` is set to `True`.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the frames to RGB.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
                The pooling size of the vision adapter.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.

        Returns:
            A `BatchFeature` containing the following keys:
            - `pixel_values_videos`: The preprocessed video patches.
            - `video_token_pooling`: For each pooled video token, the indices of the patches in
              `pixel_values_videos` to pool for that token, padded with -1.
            - `video_grids`: The (frames, height, width) grid of pooled tokens for each video.
        """
        videos = make_batched_videos(videos)

        if size is not None:
            if "height" not in size or "width" not in size:
                raise ValueError("size must contain 'height' and 'width' keys.")
        else:
            size = {**self.size}

        base_image_input_size = [size["height"], size["width"]]

        resample = resample if resample is not None else self.resample
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        patch_size = patch_size if patch_size is not None else self.patch_size
        pooling_size = pooling_size if pooling_size is not None else self.pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        videos = [to_numpy(video) for video in videos]

        batch_grids = []
        batch_crops = []
        batch_pooled_patches_idx = []

        for video in videos:
            all_crops = []
            pooled_patches_idx = []

            for frame in video:
                image_grid, crops, pooled_idx = image_to_patches_and_grids(
                    frame,
                    base_image_input_size,
                    resample,
                    image_mean,
                    image_std,
                    patch_size,
                    image_pooling_w,
                    image_pooling_h,
                )
                # Offset the pooling indices by the number of patches collected so far, so they
                # index into the flattened patch sequence of the whole video.
                offset = sum(np.prod(x.shape[:2]) for x in all_crops)
                pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
                pooled_patches_idx.append(pooled_idx_with_offset)
                all_crops.append(crops)

            video_grid = np.array([len(video), image_grid[0], image_grid[1]])
            all_crops = np.concatenate(all_crops, 0)
            pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)

            batch_grids.append(video_grid)
            batch_crops.append(all_crops)
            batch_pooled_patches_idx.append(pooled_patches_idx)

        video_grids = np.stack(batch_grids, 0)
        pixel_values_videos = np.concatenate(batch_crops, 0)
        video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)

        data = dict(
            pixel_values_videos=pixel_values_videos,
            video_token_pooling=video_token_pooling,
            video_grids=video_grids,
        )

        return BatchFeature(data, tensor_type=return_tensors)


Molmo2VideoProcessor.register_for_auto_class()
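

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, assuming the default class settings): process a
    # short random uint8 video and inspect the output shapes.
    processor = Molmo2VideoProcessor()
    video = np.random.randint(0, 256, size=(4, 240, 320, 3), dtype=np.uint8)  # 4 RGB frames
    features = processor.preprocess(video, return_tensors="np")
    # With the defaults (378x378 frames, 14-pixel patches, 3x3 pooling) this should yield roughly:
    # pixel_values_videos [4, 729, 588], video_token_pooling [4 * 81, 9], and video_grids [1, 3]
    # holding (frames, pooled_h, pooled_w) = (4, 9, 9).
    for name, value in features.items():
        print(name, value.shape)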