|
|
import csv |
|
|
import gc |
|
|
import io |
|
|
import json |
|
|
import math |
|
|
import os |
|
|
import random |
|
|
import re |
|
|
from contextlib import contextmanager |
|
|
from random import shuffle |
|
|
from threading import Thread |
|
|
|
|
|
import albumentations |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms as transforms |
|
|
from decord import VideoReader |
|
|
from einops import rearrange |
|
|
from func_timeout import FunctionTimedOut, func_timeout |
|
|
from packaging import version as pver |
|
|
from PIL import Image |
|
|
from torch.utils.data import BatchSampler, Sampler |
|
|
from torch.utils.data.dataset import Dataset |
|
|
|
|
|
# Timeout handed to ``func_timeout`` when decoding a batch of video frames
# (used by ImageVideoDataset.get_batch); func_timeout measures it in seconds.
VIDEO_READER_TIMEOUT: int = 20
|
|
|
|
|
def _randint_scalar(low, high):
    """Return one ``torch.randint`` draw from ``[low, high)``, or ``low`` if that range is empty.

    ``torch.randint`` raises when ``high <= low``, which happens for very small
    spatial sizes (e.g. ``w // 4 == w // 4 * 3`` whenever ``w < 4``).
    """
    if high <= low:
        return low
    return torch.randint(low, high, (1,)).item()


def _random_box(h, w):
    """Sample a box clipped to an ``(h, w)`` frame; returns ``(start_y, end_y, start_x, end_x)``.

    The centre is drawn uniformly over the frame; each side length is drawn from
    ``[dim // 4, dim // 4 * 3)``.
    """
    center_x = torch.randint(0, w, (1,)).item()
    center_y = torch.randint(0, h, (1,)).item()
    block_size_x = _randint_scalar(w // 4, w // 4 * 3)
    block_size_y = _randint_scalar(h // 4, h // 4 * 3)

    start_x = max(center_x - block_size_x // 2, 0)
    end_x = min(center_x + block_size_x // 2, w)
    start_y = max(center_y - block_size_y // 2, 0)
    end_y = min(center_y + block_size_y // 2, h)
    return start_y, end_y, start_x, end_x


def _ellipse_region(h, w, center_y, center_x, a, b):
    """Boolean ``(h, w)`` map of pixels strictly inside an axis-aligned ellipse.

    ``a`` is the semi-axis along x (columns) and ``b`` along y (rows). The test
    ``(i - cy)**2 / b**2 + (j - cx)**2 / a**2 < 1`` is evaluated multiplied
    through by ``a**2 * b**2`` in exact integer arithmetic. A zero semi-axis
    therefore gives an empty region instead of ``ZeroDivisionError``. ``a == b``
    gives a disk of that radius.
    """
    rows = torch.arange(h, dtype=torch.int64).view(h, 1)
    cols = torch.arange(w, dtype=torch.int64).view(1, w)
    dy2 = (rows - center_y) ** 2
    dx2 = (cols - center_x) ** 2
    return dy2 * (a * a) + dx2 * (b * b) < (a * a) * (b * b)


def get_random_mask(shape, image_start_only=False):
    """Sample a random binary mask for a clip.

    Args:
        shape: ``(f, c, h, w)`` of the clip. ``c`` is ignored; the mask always
            has a single channel.
        image_start_only: if True, skip the random pattern. Mark every frame
            except the first, or the only frame when ``f == 1``.

    Returns:
        ``torch.uint8`` tensor of shape ``(f, 1, h, w)`` with 1 on the selected
        region and 0 elsewhere.

    Raises:
        ValueError: if the sampled pattern index is unknown. This cannot happen
            with the fixed choice lists below.
    """
    f, _, h, w = shape
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if image_start_only:
        if f != 1:
            mask[1:, :, :, :] = 1
        else:
            mask[:, :, :, :] = 1
        return mask

    if f != 1:
        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
    else:
        mask_index = np.random.choice([0, 1], p=[0.2, 0.8])

    if mask_index == 0:
        # One random box, identical on every frame.
        start_y, end_y, start_x, end_x = _random_box(h, w)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        # The whole clip.
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        # Every frame from a random index in [1, 5) onward.
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        # Middle frames: the same count k in [1, 5) is left out at both ends.
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        # A random box over a random contiguous range of frames.
        start_y, end_y, start_x, end_x = _random_box(h, w)
        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 5:
        # Independent per-pixel coin flips.
        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
    elif mask_index == 6:
        # Small random blocks on a random subset of frames.
        num_frames_to_mask = random.randint(1, max(f // 2, 1))
        frames_to_mask = random.sample(range(f), num_frames_to_mask)
        # max(..., 1) keeps random.randint(1, n) valid when h or w < 4.
        max_block_h = max(h // 4, 1)
        max_block_w = max(w // 4, 1)
        for i in frames_to_mask:
            block_height = random.randint(1, max_block_h)
            block_width = random.randint(1, max_block_w)
            top_left_y = random.randint(0, h - block_height)
            top_left_x = random.randint(0, w - block_width)
            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
    elif mask_index == 7:
        # Random axis-aligned ellipse on every frame.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        a = _randint_scalar(min(w, h) // 8, min(w, h) // 4)
        b = _randint_scalar(min(h, w) // 8, min(h, w) // 4)
        mask.masked_fill_(_ellipse_region(h, w, center_y, center_x, a, b), 1)
    elif mask_index == 8:
        # Random disk on every frame.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        radius = _randint_scalar(min(h, w) // 8, min(h, w) // 4)
        mask.masked_fill_(_ellipse_region(h, w, center_y, center_x, radius, radius), 1)
    elif mask_index == 9:
        # Each whole frame independently with probability ~0.5.
        for idx in range(f):
            if np.random.rand() > 0.5:
                mask[idx, :, :, :] = 1
    else:
        raise ValueError(f"The mask_index {mask_index} is not define")

    return mask
|
|
|
|
|
class Camera:
    """Camera parsed from one numeric row of a pose file.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    ``entry[1:5]`` supplies ``fx, fy, cx, cy`` (stored unchanged) and
    ``entry[7:]`` supplies a 3x4 world-to-camera matrix in row-major order,
    which is lifted to 4x4 homogeneous form. Its inverse is kept as the
    camera-to-world matrix.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]

        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)

        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
|
|
|
|
|
def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with 'ij' indexing on every supported torch version.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The ``indexing`` keyword only exists from torch 1.10 on. Older releases are
    called without it; their fixed behaviour corresponds to 'ij'.
    """
    legacy_torch = pver.parse(torch.__version__) < pver.parse('1.10')
    extra = {} if legacy_torch else {'indexing': 'ij'}
    return torch.meshgrid(*args, **extra)
|
|
|
|
|
def get_relative_pose(cam_params):
    """Express every camera pose relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: sequence of objects exposing 4x4 ``w2c_mat`` and ``c2w_mat``.

    Returns:
        ``float32`` array of shape ``(len(cam_params), 4, 4)`` of camera-to-world
        matrices. Entry 0 is the fixed reference pose (the identity, because
        ``cam_to_origin`` is 0). Entry k is ``ref @ w2c_0 @ c2w_k``.
    """
    w2c_mats = [cam.w2c_mat for cam in cam_params]
    c2w_mats = [cam.c2w_mat for cam in cam_params]

    cam_to_origin = 0
    reference_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])

    # World coordinates -> reference frame anchored at the first camera.
    world_to_reference = reference_c2w @ w2c_mats[0]

    relative_poses = [reference_c2w]
    for c2w in c2w_mats[1:]:
        relative_poses.append(world_to_reference @ c2w)
    return np.array(relative_poses, dtype=np.float32)
|
|
|
|
|
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plücker ray embeddings for a batch of camera poses.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    (with the cross-product axis made explicit, see below).

    Args:
        K: intrinsics whose last dimension holds ``fx, fy, cx, cy`` in pixels;
            the callers in this file pass shape ``(B, F, 4)``.
        c2w: camera-to-world matrices of shape ``(B, F, 4, 4)``.
        H, W: output image height and width.
        device: device on which the pixel grid is created; ``K`` and ``c2w``
            are used as given.

    Returns:
        Tensor of shape ``(B, F, H, W, 6)``: the moment ``o x d`` followed by the
        unit ray direction ``d`` for every pixel.
    """
    B = K.shape[0]

    # Pixel grid ('ij' indexing: j runs over rows, i over columns), sampled at
    # pixel centres.
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # (B, 1, HW)
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # (B, 1, HW)

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each (B, F, 1)

    # Camera-space directions through each pixel on the z = 1 plane.
    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)  # needed so torch.stack sees equal shapes

    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # (B, F, HW, 3)
    rays_o = c2w[..., :3, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)

    # Cross over the explicit xyz axis. Without ``dim``, torch.cross uses the
    # first dimension of size 3, which is wrong whenever B, F or H*W equals 3
    # (and that form is deprecated). A non-negative index is used because some
    # older releases treated dim=-1 as the "auto-detect" sentinel.
    rays_dxo = torch.cross(rays_o, rays_d, dim=rays_o.dim() - 1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)

    return plucker
|
|
|
|
|
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Read a CameraCtrl-style pose file and build its Plücker embedding.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The first line of the file is skipped (presumably a header). Every other
    non-blank line is parsed as whitespace-separated floats, one camera per line.

    Args:
        pose_file_path: path of the text file to read.
        width, height, original_pose_width, original_pose_height, device:
            forwarded unchanged to ``process_pose_params``.
        return_poses: if True, return the parsed rows without computing the
            embedding.

    Returns:
        The parsed rows (list of lists of float) when ``return_poses`` is True,
        otherwise the per-frame Plücker embedding from ``process_pose_params``.
    """
    with open(pose_file_path, 'r') as f:
        lines = f.readlines()

    # split() (not split(' ')) tolerates repeated spaces/tabs; blank lines such
    # as trailing empty lines would otherwise make float('') raise.
    cam_params = [[float(x) for x in line.split()] for line in lines[1:] if line.strip()]
    if return_poses:
        return cam_params

    # The embedding computation is exactly process_pose_params; delegate rather
    # than keep a second copy in sync.
    return process_pose_params(
        cam_params,
        width=width,
        height=height,
        original_pose_width=original_pose_width,
        original_pose_height=original_pose_height,
        device=device,
    )
|
|
|
|
|
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
    """Build per-frame Plücker embeddings from raw camera rows.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: sequence of per-frame rows accepted by ``Camera``
            (intrinsics at ``[1:5]``, 3x4 world-to-camera matrix at ``[7:]``).
            The intrinsics are multiplied by ``width`` / ``height`` below, so
            they are presumably normalized to image size (verify against data).
        width, height: target sample size in pixels.
        original_pose_width, original_pose_height: frame size the poses refer
            to; only their aspect ratio is used.
        device: device used for the pixel grid in ``ray_condition``.

    Returns:
        Tensor of shape ``(F, height, width, 6)``, one entry per input row, with
        poses taken relative to the first camera.
    """
    cams = [Camera(row) for row in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    # Rescale the focal length on the axis where the aspect ratios disagree,
    # i.e. the axis cropped when the original frame is resized to cover the
    # sample frame.
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam in cams:
            cam.fx = resized_ori_w * cam.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam in cams:
            cam.fy = resized_ori_h * cam.fy / height

    intrinsic = np.asarray(
        [[cam.fx * width, cam.fy * height, cam.cx * width, cam.cy * height] for cam in cams],
        dtype=np.float32,
    )

    K = torch.as_tensor(intrinsic)[None]                    # (1, F, 4)
    c2ws = torch.as_tensor(get_relative_pose(cams))[None]   # (1, F, 4, 4)

    # ray_condition returns (1, F, H, W, 6); dropping the batch axis is all that
    # is needed. A permute to channels-first and back would only add copies.
    return ray_condition(K, c2ws, height, width, device=device)[0]
|
|
|
|
|
def derive_ground_object_from_instruction(instruction: str) -> str:
    """Guess which object phrase an edit instruction refers to.

    Heuristics are tried in order on the stripped instruction, with trailing
    periods removed:

    1. ``replace X with/by ...`` or ``swap X with ...``                -> ``X``
    2. ``remove/delete/erase/eliminate X [from|in|at|on|...]``         -> ``X``
    3. instructions starting with ``add`` / ``insert``                 -> fallback
    4. ``change/make [the|a|an] WORD``                                 -> ``[det] WORD``

    Returns ``'the target area'`` for empty/None input or when nothing matches.
    """
    fallback = 'the target area'

    text = (instruction or '').strip()
    if not text:
        return fallback
    text = text.rstrip('.').strip()

    # Patterns whose first group is the object phrase; first non-empty hit wins.
    object_patterns = (
        r"\breplace\s+(.*?)\s+(?:with|by)\b",
        r"\bswap\s+(.*?)\s+with\b",
        r"\b(?:remove|delete|erase|eliminate)\s+(.*?)(?:\s+(?:from|in|at|on|over|under|near|by)\b|[.,;]|$)",
    )
    for pattern in object_patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        phrase = found.group(1).strip(' .,:;') if found else ''
        if phrase:
            return phrase

    # Insertions have no pre-existing object to ground.
    if re.search(r"^\s*(?:add|insert)\b", text, flags=re.IGNORECASE):
        return fallback

    found = re.search(r"\b(?:change|make)\s+(?:(the|a|an)\s+)?([A-Za-z][A-Za-z0-9\-]*)", text, flags=re.IGNORECASE)
    if found is None:
        return fallback
    determiner, noun = found.group(1), found.group(2)
    return f"{determiner or ''} {noun}".strip()
|
|
|
|
|
class ImageVideoSampler(BatchSampler):
    """Batch sampler that never mixes images and videos in one batch.

    Every index drawn from ``sampler`` is routed to an ``'image'`` or ``'video'``
    bucket according to ``dataset.dataset[idx].get('type', 'image')``. A bucket
    is emitted as a batch as soon as it holds exactly ``batch_size`` indices;
    the video bucket is checked first and at most one batch is emitted per
    sampled index.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Object whose ``.dataset`` attribute is indexable by
            the sampled indices and yields dict-like entries.
        batch_size (int): Size of mini-batch.
        drop_last (bool): Stored on the instance; ``__iter__`` does not consult
            it. Indices still waiting in a bucket when the base sampler is
            exhausted are kept in ``self.bucket`` (instance state) rather than
            yielded.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        batch_size_ok = isinstance(batch_size, int) and batch_size > 0
        if not batch_size_ok:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')

        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Pending indices per content type; persists across __iter__ calls.
        self.bucket = {'image': [], 'video': []}

    def __iter__(self):
        for sample_idx in self.sampler:
            kind = self.dataset.dataset[sample_idx].get('type', 'image')
            self.bucket[kind].append(sample_idx)

            # Emit whichever bucket is full (video first), then stop checking.
            for name in ('video', 'image'):
                pending = self.bucket[name]
                if len(pending) == self.batch_size:
                    yield pending[:]
                    del pending[:]
                    break
|
|
|
|
|
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Open a decord ``VideoReader`` for the duration of a ``with`` block.

    All arguments are forwarded to ``VideoReader``. On exit, whether normal or
    through an exception, the local reference is dropped and ``gc.collect()``
    runs, presumably so the decoder's native resources are released promptly.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        del reader
        gc.collect()
|
|
|
|
|
def get_video_reader_batch(video_reader, batch_index):
    """Decode the frames at ``batch_index`` and return them via ``.asnumpy()``."""
    return video_reader.get_batch(batch_index).asnumpy()
|
|
|
|
|
def resize_frame(frame, target_short_side):
    """Downscale ``frame`` so its shorter side equals ``target_short_side``.

    ``frame`` is indexed as ``(height, width, channels)``. Frames whose shorter
    side is already below the target are returned unchanged (never upscaled).
    The longer side is scaled proportionally and truncated to an int. For
    square frames, width is treated as the shorter side.
    """
    height, width, _ = frame.shape
    height_is_short = height < width
    short_side = height if height_is_short else width

    if target_short_side > short_side:
        return frame

    # cv2.resize takes the destination size as (width, height).
    if height_is_short:
        dsize = (int(target_short_side * width / height), target_short_side)
    else:
        dsize = (target_short_side, int(target_short_side * height / width))
    return cv2.resize(frame, dsize)
|
|
|
|
|
class VideoEditDataset(Dataset):
    """Paired (original, edited) video dataset for instruction-driven editing.

    ``ann_path`` is a JSON file containing either:

    - a list of entries, used as-is (each needs ``original_video``,
      ``edited_video`` and ``text``), or
    - a dict whose values carry ``original_video``, ``edited_video``,
      ``edit_instruction`` and optionally ``type`` / ``resolution``.

    Each sample stacks ``source_frames`` frames of the original clip followed by
    ``edit_frames`` frames of the edited clip along the frame axis.

    Note: ``video_sample_height`` / ``video_sample_width`` /
    ``video_sample_n_frames`` and ``enable_inpaint`` are stored but not used by
    the loading code in this class.
    """

    def __init__(
        self,
        ann_path,
        data_root=None,
        video_sample_height: int = None,
        video_sample_width: int = None,
        video_sample_stride=1,
        video_sample_n_frames=65,
        source_frames=33,
        edit_frames=32,
        text_drop_ratio=0.1,
        enable_bucket=False,
        enable_inpaint=False,
        instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
    ):
        with open(ann_path) as f:
            dataset = json.load(f)
        if isinstance(dataset, dict):
            dataset = [
                {
                    "original_video": info["original_video"],
                    "edited_video": info["edited_video"],
                    "text": info["edit_instruction"],
                    "type": info.get("type", "video"),
                    "resolution": info.get("resolution", None),
                }
                for info in dataset.values()
            ]

        self.data_root = data_root
        self.dataset = dataset
        self.length = len(self.dataset)

        self.source_frames = source_frames
        self.edit_frames = edit_frames
        self.video_sample_n_frames = video_sample_n_frames

        self.instruction_template = instruction_template
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.video_sample_stride = video_sample_stride

        # With bucketing the native resolution is kept, so no fixed sample size.
        if enable_bucket:
            self.video_sample_height = None
            self.video_sample_width = None
        else:
            self.video_sample_height = video_sample_height
            self.video_sample_width = video_sample_width

    @staticmethod
    def _frame_indices(num_frames, stride, length):
        """``num_frames`` evenly spaced indices starting at 0, ``stride`` apart, clamped to the last frame."""
        last = min((num_frames - 1) * stride, max(length - 1, 0))
        return np.linspace(0, last, num_frames, dtype=int)

    @staticmethod
    def _resize_and_center_crop_batch(frames_np, target_h, target_w):
        """Resize each ``(H, W, C)`` frame to cover the target size, then centre-crop to it."""
        resized = []
        for frame in frames_np:
            h, w = frame.shape[0], frame.shape[1]
            scale = max(target_h / h, target_w / w)
            new_h = int(round(h * scale))
            new_w = int(round(w * scale))
            frame_resized = cv2.resize(frame, (new_w, new_h))
            y0 = max((new_h - target_h) // 2, 0)
            x0 = max((new_w - target_w) // 2, 0)
            resized.append(frame_resized[y0:y0 + target_h, x0:x0 + target_w])
        return np.stack(resized, axis=0)

    @classmethod
    def _crop_to_common_size(cls, *clips):
        """Bring ``(F, H, W, C)`` clips to the smallest height and width among them."""
        target_h = min(clip.shape[1] for clip in clips)
        target_w = min(clip.shape[2] for clip in clips)
        return [
            clip if clip.shape[1:3] == (target_h, target_w)
            else cls._resize_and_center_crop_batch(clip, target_h, target_w)
            for clip in clips
        ]

    @staticmethod
    def _to_unit_tensor(frames_np):
        """``(F, H, W, C)`` numpy frames -> ``(F, C, H, W)`` float tensor divided by 255."""
        return torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous() / 255.

    def _display_path(self, data_info, key):
        """Best-effort full path of ``data_info[key]`` for error messages (never raises on joins)."""
        path = data_info.get(key, '')
        try:
            return os.path.join(self.data_root, path) if self.data_root else path
        except TypeError:
            return path

    def load_video_pair(self, original_path, edited_path):
        """Load an (original, edited) video pair at native resolution, for bucket training.

        Frames are taken from the start of each video: ``source_frames`` /
        ``edit_frames`` frames, ``video_sample_stride`` apart and clamped to the
        video length. If the two clips differ in size, each is resized and
        centre-cropped to the smaller height and the smaller width.

        Returns:
            With ``enable_bucket``: numpy array ``(Fs + Fe, H, W, C)`` as decoded.
            Otherwise: float tensor ``(Fs + Fe, C, H, W)`` divided by 255.
        """
        if self.data_root is not None:
            original_path = os.path.join(self.data_root, original_path)
            edited_path = os.path.join(self.data_root, edited_path)

        with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
                VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
            orig_indices = self._frame_indices(self.source_frames, self.video_sample_stride, len(orig_reader))
            edit_indices = self._frame_indices(self.edit_frames, self.video_sample_stride, len(edit_reader))
            orig_frames = get_video_reader_batch(orig_reader, orig_indices)
            edit_frames = get_video_reader_batch(edit_reader, edit_indices)

        clips = self._crop_to_common_size(orig_frames, edit_frames)
        if self.enable_bucket:
            return np.concatenate(clips, axis=0)
        return torch.cat([self._to_unit_tensor(clip) for clip in clips], dim=0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict. On a loading error, retry with a random index until one loads."""
        while True:
            # Re-read the entry on every attempt so a failed sample is actually
            # replaced by the randomly drawn one.
            data_info = self.dataset[idx % len(self.dataset)]
            try:
                pixel_values = self.load_video_pair(
                    data_info['original_video'],
                    data_info['edited_video'],
                )

                text = data_info['text']
                if self.instruction_template and "{edit_instruction}" in self.instruction_template:
                    text = self.instruction_template.format(edit_instruction=text)
                if random.random() < self.text_drop_ratio:
                    text = ''

                return {
                    "pixel_values": pixel_values,
                    "text": text,
                    "data_type": "video",
                    "idx": idx,
                }
            except Exception as e:
                print(
                    f"Error loading video pair: {e}\n"
                    f" original={self._display_path(data_info, 'original_video')}\n"
                    f" edited ={self._display_path(data_info, 'edited_video')}"
                )
                idx = random.randint(0, self.length - 1)
|
|
|
|
|
class VideoEditReasoningDataset(Dataset):
    """Video editing dataset with an intermediate "grounded" clip.

    Each sample concatenates, along the frame axis:

    - ``source_frames`` frames of the original video,
    - ``reasoning_frames`` frames of the grounded video, and
    - ``edit_frames`` frames of the edited video.

    The prompt comes from ``instruction_template``, filled via ``str.format``
    with ``edit_instruction`` and ``ground_instrction`` (sic). The latter is
    produced by ``derive_ground_object_from_instruction``.

    ``ann_path`` is a JSON list of entries (used as-is) or a dict whose values
    carry ``original_video``, ``grounded_video`` (or ``ground_video``),
    ``edited_video`` and ``edit_instruction`` (or ``text``).

    Note: ``video_sample_height`` / ``video_sample_width`` /
    ``video_sample_n_frames`` and ``enable_inpaint`` are stored but not used by
    the loading code in this class.
    """

    def __init__(
        self,
        ann_path,
        data_root=None,
        video_sample_height: int = None,
        video_sample_width: int = None,
        video_sample_stride=1,
        video_sample_n_frames=65,
        source_frames=33,
        reasoning_frames=4,
        edit_frames=32,
        text_drop_ratio=0.1,
        enable_bucket=False,
        enable_inpaint=False,
        instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instrction}, and finally the same scene but {edit_instruction}",
        ground_frame_interval=8,
    ):
        with open(ann_path) as f:
            dataset = json.load(f)
        if isinstance(dataset, dict):
            new_dataset = []
            for info in dataset.values():
                text_content = info.get("edit_instruction", info.get("text", ""))
                # Accept either spelling of the grounded-video key.
                grounded_key = "grounded_video" if "grounded_video" in info else "ground_video"
                new_dataset.append({
                    "original_video": info["original_video"],
                    "grounded_video": info[grounded_key],
                    "edited_video": info["edited_video"],
                    "text": text_content,
                    "edit_instruction": text_content,
                    "type": info.get("type", "video"),
                    "resolution": info.get("resolution", None),
                })
            dataset = new_dataset

        self.data_root = data_root
        self.dataset = dataset
        self.length = len(self.dataset)

        self.source_frames = source_frames
        self.reasoning_frames = reasoning_frames
        self.edit_frames = edit_frames
        self.video_sample_n_frames = video_sample_n_frames
        # Stride between the grounded frames that are sampled (0, k, 2k, ...).
        self.ground_frame_interval = ground_frame_interval

        self.instruction_template = instruction_template
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.video_sample_stride = video_sample_stride

        # With bucketing the native resolution is kept, so no fixed sample size.
        if enable_bucket:
            self.video_sample_height = None
            self.video_sample_width = None
        else:
            self.video_sample_height = video_sample_height
            self.video_sample_width = video_sample_width

    @staticmethod
    def _frame_indices(num_frames, stride, length):
        """``num_frames`` evenly spaced indices starting at 0, ``stride`` apart, clamped to the last frame."""
        last = min((num_frames - 1) * stride, max(length - 1, 0))
        return np.linspace(0, last, num_frames, dtype=int)

    def _ground_frame_indices(self, ground_length):
        """First ``reasoning_frames`` of 0, interval, 2*interval, ...; padded by repeating the last index."""
        # arange always contains 0 because its stop is at least 1.
        candidates = np.arange(0, max(ground_length, 1), self.ground_frame_interval, dtype=int)
        indices = candidates[:self.reasoning_frames]
        if len(indices) < self.reasoning_frames:
            indices = np.pad(
                indices, (0, self.reasoning_frames - len(indices)), constant_values=indices[-1]
            )
        return indices

    @staticmethod
    def _resize_and_center_crop_batch(frames_np, target_h, target_w):
        """Resize each ``(H, W, C)`` frame to cover the target size, then centre-crop to it."""
        resized = []
        for frame in frames_np:
            h, w = frame.shape[0], frame.shape[1]
            scale = max(target_h / h, target_w / w)
            new_h = int(round(h * scale))
            new_w = int(round(w * scale))
            frame_resized = cv2.resize(frame, (new_w, new_h))
            y0 = max((new_h - target_h) // 2, 0)
            x0 = max((new_w - target_w) // 2, 0)
            resized.append(frame_resized[y0:y0 + target_h, x0:x0 + target_w])
        return np.stack(resized, axis=0)

    @classmethod
    def _crop_to_common_size(cls, *clips):
        """Bring ``(F, H, W, C)`` clips to the smallest height and width among them."""
        target_h = min(clip.shape[1] for clip in clips)
        target_w = min(clip.shape[2] for clip in clips)
        return [
            clip if clip.shape[1:3] == (target_h, target_w)
            else cls._resize_and_center_crop_batch(clip, target_h, target_w)
            for clip in clips
        ]

    @staticmethod
    def _to_unit_tensor(frames_np):
        """``(F, H, W, C)`` numpy frames -> ``(F, C, H, W)`` float tensor divided by 255."""
        return torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous() / 255.

    def load_video_pair(self, original_path, grounded_path, edited_path):
        """Load the (original, grounded, edited) clips and stack them along the frame axis.

        Original and edited frames are sampled from the start, ``video_sample_stride``
        apart. Grounded frames come from ``_ground_frame_indices``. Clips of
        different sizes are resized and centre-cropped to the smallest height
        and width.

        Returns:
            With ``enable_bucket``: numpy array ``(Fs + Fr + Fe, H, W, C)``.
            Otherwise: float tensor ``(Fs + Fr + Fe, C, H, W)`` divided by 255.
        """
        if self.data_root is not None:
            original_path = os.path.join(self.data_root, original_path)
            grounded_path = os.path.join(self.data_root, grounded_path)
            edited_path = os.path.join(self.data_root, edited_path)

        with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
                VideoReader_contextmanager(grounded_path, num_threads=2) as ground_reader, \
                VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
            orig_indices = self._frame_indices(self.source_frames, self.video_sample_stride, len(orig_reader))
            ground_indices = self._ground_frame_indices(len(ground_reader))
            edit_indices = self._frame_indices(self.edit_frames, self.video_sample_stride, len(edit_reader))

            orig_frames = get_video_reader_batch(orig_reader, orig_indices)
            ground_frames = get_video_reader_batch(ground_reader, ground_indices)
            edit_frames = get_video_reader_batch(edit_reader, edit_indices)

        clips = self._crop_to_common_size(orig_frames, ground_frames, edit_frames)
        if self.enable_bucket:
            return np.concatenate(clips, axis=0)
        return torch.cat([self._to_unit_tensor(clip) for clip in clips], dim=0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict. On a loading error, retry with a random index until one loads."""
        while True:
            # Re-read the entry on every attempt so a failed sample is actually
            # replaced by the randomly drawn one.
            data_info = self.dataset[idx % len(self.dataset)]
            try:
                pixel_values = self.load_video_pair(
                    data_info['original_video'],
                    data_info.get('grounded_video', data_info.get('ground_video')),
                    data_info['edited_video'],
                )

                edit_text = data_info.get('edit_instruction', data_info.get('text', ''))
                ground_instr = derive_ground_object_from_instruction(edit_text)

                text = edit_text
                if self.instruction_template:
                    # ``ground_instrction`` (sic) is the placeholder name used by
                    # the default template and existing user templates.
                    text = self.instruction_template.format(edit_instruction=edit_text, ground_instrction=ground_instr)
                if random.random() < self.text_drop_ratio:
                    text = ''

                return {
                    "pixel_values": pixel_values,
                    "text": text,
                    "data_type": "video",
                    "idx": idx,
                }
            except Exception as e:
                print(f"Error loading video triplet: {e}")
                idx = random.randint(0, self.length - 1)
|
|
|
|
|
class ImageVideoDataset(Dataset):
    """Mixed image / video text-conditioned dataset.

    Annotations are read from a ``.csv`` (via ``csv.DictReader``) or ``.json``
    file. Each entry needs ``file_path`` and ``text``. An entry whose ``type`` is
    ``'video'`` is loaded as a clip; anything else, including a missing
    ``type``, is loaded as an image.

    Note: ``enable_inpaint`` is stored but not used by this class.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        return_file_name=False,
    ):
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            # newline='' is what the csv module expects, so quoted fields that
            # contain line breaks are parsed correctly.
            with open(ann_path, 'r', newline='') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            with open(ann_path, 'r') as jsonfile:
                dataset = json.load(jsonfile)
        else:
            raise ValueError(f"Unsupported annotation file {ann_path!r}: expected a .csv or .json file.")

        self.data_root = data_root

        # Optional video over-sampling: every non-video entry once, followed by
        # all video entries repeated ``video_repeat`` times (original order).
        if video_repeat > 0:
            non_videos = [data for data in dataset if data.get('type', 'image') != 'video']
            videos = [data for data in dataset if data.get('type', 'image') == 'video']
            self.dataset = non_videos + videos * video_repeat
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        # Only the [start, end) fraction of each video is sampled from.
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video parameters.
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image parameters.
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        # Decoded video frames are pre-shrunk so their short side is at most this.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load entry ``idx`` and return ``(pixel_values, text, data_type, path)``.

        Without bucketing, pixels go through ``video_transforms`` /
        ``image_transforms``. With bucketing, a raw numpy array with a leading
        frame axis is returned.

        Raises:
            ValueError: if a video has no usable frames, times out while
                decoding, or fails to decode.
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video':
            video_id, text = data_info['file_path'], data_info['text']
            video_dir = video_id if self.data_root is None else os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                total_frames = len(video_reader)
                # Frames actually sampled: capped by how many fit in the kept
                # fraction of the video at this stride.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(total_frames * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No Frames in video.")

                video_length = int(self.video_length_drop_end * total_frames)
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    pixel_values = np.array(
                        [resize_frame(frame, self.larger_side_of_image_and_video) for frame in pixel_values]
                    )
                except FunctionTimedOut as e:
                    raise ValueError(f"Read {idx} timeout.") from e
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    # Drop this frame's reference so the context manager's
                    # cleanup is not holding the last reader reference alone.
                    del video_reader

            if not self.enable_bucket:
                pixel_values = self.video_transforms(pixel_values)

            if random.random() < self.text_drop_ratio:
                text = ''
            return pixel_values, text, 'video', video_dir

        image_path, text = data_info['file_path'], data_info['text']
        if self.data_root is not None:
            image_path = os.path.join(self.data_root, image_path)
        # convert() yields a new in-memory image, so the file can be closed here.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if not self.enable_bucket:
            image = self.image_transforms(image).unsqueeze(0)
        else:
            image = np.expand_dims(np.array(image), 0)
        if random.random() < self.text_drop_ratio:
            text = ''
        return image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict. On a failure, retry random indices of the same type."""
        # The requested entry's type is kept across retries, so a replacement
        # sample has the same modality.
        data_type = self.dataset[idx % len(self.dataset)].get('type', 'image')
        while True:
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, loaded_type, file_path = self.get_batch(idx)
                sample = {
                    "pixel_values": pixel_values,
                    "text": name,
                    "data_type": loaded_type,
                    "idx": idx,
                }
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)
                return sample
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)
|
|
|
|
|
class ImageVideoEditDataset(Dataset):
    """Paired source/edited image-or-video dataset for instruction-based editing.

    Each annotation entry is either a video pair (``original_video`` /
    ``edited_video``; used when ``type`` is ``'video'`` or missing) or an image
    pair (``original_image`` / ``edited_image``), plus an edit instruction.
    ``__getitem__`` returns a dict with both halves of the pair, the
    (optionally templated) instruction text, the data type and the index.
    """

    def __init__(
        self,
        ann_path,
        data_root=None,
        video_sample_size=512,
        video_sample_stride=1,
        source_frames=33,
        target_frames=32,
        text_drop_ratio=0.1,
        enable_bucket=False,
        enable_inpaint=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
    ):
        """Load annotations and build the preprocessing pipelines.

        Args:
            ann_path: JSON annotation file; a list of entries, or a dict of
                entries that ``_load_annotations`` normalizes into a list.
            data_root: Optional directory joined in front of every media path.
            video_sample_size: Output size; an int means a square ``(s, s)``.
            video_sample_stride: Frame stride used when sampling clip indices.
            source_frames: Number of frames sampled from the original video.
            target_frames: Number of frames sampled from the edited video.
            text_drop_ratio: Probability that the prompt is replaced by ``''``.
            enable_bucket: If True, ``get_batch`` returns raw numpy arrays
                instead of transformed tensors.
            enable_inpaint: Stored on the instance; not read by this class.
            video_length_drop_start: Stored on the instance; not read by this class.
            video_length_drop_end: Stored on the instance; not read by this class.
            instruction_template: Format string; when it contains
                ``{edit_instruction}`` the raw instruction is substituted into it.
        """
        self.data_root = data_root
        self.dataset = self._load_annotations(ann_path)
        self.length = len(self.dataset)

        self.video_sample_stride = video_sample_stride
        self.source_frames = source_frames
        self.target_frames = target_frames
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        if isinstance(video_sample_size, int):
            self.video_sample_size = (video_sample_size, video_sample_size)
        else:
            self.video_sample_size = tuple(video_sample_size)
        # Video frames are already float tensors in [0, 1] (see _frames_to_tensor),
        # so no ToTensor step here.
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        # Images are PIL images, hence ToTensor before normalization.
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.video_sample_size)),
            transforms.CenterCrop(self.video_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.instruction_template = instruction_template
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        # Only one sample size exists in this dataset; its shorter side is used.
        self.larger_side_of_image_and_video = min(self.video_sample_size)

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation JSON and normalize dict-form annotations.

        A JSON list is returned as-is. A JSON object ``{key: info}`` becomes a
        list of copies of its values, each gaining:

        * ``text`` -- ``edit_instruction``, else ``instruction``, else an
          existing ``text``; a missing or ``None`` value becomes ``""``.
        * ``file_path`` -- ``original_video`` for video entries (``type``
          missing or ``'video'``), otherwise ``original_image`` (``""`` if absent).
        """
        # Binary mode lets json detect UTF-8/16/32 (and a UTF-8 BOM) itself;
        # the ``with`` block closes the handle the old ``open()`` call leaked.
        with open(ann_path, "rb") as ann_file:
            dataset = json.load(ann_file)
        if not isinstance(dataset, dict):
            return dataset

        normalized = []
        for info in dataset.values():
            entry = dict(info)
            if "edit_instruction" in entry:
                entry["text"] = entry["edit_instruction"]
            elif "instruction" in entry:
                entry["text"] = entry["instruction"]
            elif "text" not in entry:
                entry["text"] = ""
            if entry["text"] is None:
                entry["text"] = ""
            if info.get("type", "video") == "video":
                entry["file_path"] = entry.get("original_video", "")
            else:
                entry["file_path"] = entry.get("original_image", "")
            normalized.append(entry)
        return normalized

    def _resolve_path(self, rel_path):
        """Join ``rel_path`` onto ``self.data_root`` when a root is configured."""
        if self.data_root is None:
            return rel_path
        return os.path.join(self.data_root, rel_path)

    @staticmethod
    def _open_rgb(path):
        """Open an image file, convert it to RGB, and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _sample_indices(self, num_frames, num_available):
        """Return ``num_frames`` evenly spaced integer frame indices starting at 0.

        The span is ``(num_frames - 1) * self.video_sample_stride`` frames,
        capped at the last available frame (``num_available - 1``); when capped,
        the spacing shrinks below the stride.
        """
        start_idx = 0
        end_idx = min(start_idx + (num_frames - 1) * self.video_sample_stride, num_available - 1)
        return np.linspace(start_idx, end_idx, num_frames, dtype=int)

    def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
        """Apply ``_resize_and_center_crop_image`` to every frame of an (F, H, W, ...) array."""
        return np.stack(
            [self._resize_and_center_crop_image(frame, target_h, target_w) for frame in frames_np],
            axis=0,
        )

    def _resize_and_center_crop_image(self, image_np, target_h, target_w):
        """Scale ``image_np`` so it covers ``(target_h, target_w)``, then center-crop to it."""
        h, w = image_np.shape[0], image_np.shape[1]
        # Use the larger factor so both sides reach at least the target size.
        scale = max(target_h / h, target_w / w)
        new_h = int(round(h * scale))
        new_w = int(round(w * scale))
        image_resized = cv2.resize(image_np, (new_w, new_h))
        y0 = max((new_h - target_h) // 2, 0)
        x0 = max((new_w - target_w) // 2, 0)
        return image_resized[y0:y0 + target_h, x0:x0 + target_w]

    def _crop_to_common_size(self, *batches):
        """Bring every (F, H, W, C) batch to the smallest H and W found among them.

        Batches already at that size are returned unchanged (same object).
        """
        target_h = min(batch.shape[1] for batch in batches)
        target_w = min(batch.shape[2] for batch in batches)
        return [
            batch if batch.shape[1:3] == (target_h, target_w)
            else self._resize_and_center_crop_batch(batch, target_h, target_w)
            for batch in batches
        ]

    def _frames_to_tensor(self, frames):
        """Convert (F, H, W, C) 0-255 frames to an (F, C, H, W) tensor and apply video_transforms."""
        tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous() / 255.
        return self.video_transforms(tensor)

    def _load_video_pair(self, idx, data_info):
        """Read the source and edited clips for ``data_info``.

        Returns:
            ``(src, tgt)`` numpy arrays when ``enable_bucket`` is set, otherwise
            transformed tensors.

        Raises:
            ValueError: a video is shorter than the requested frame count,
                reading timed out, or frame extraction failed.
        """
        src_path = self._resolve_path(data_info['original_video'])
        tgt_path = self._resolve_path(data_info['edited_video'])

        from decord import cpu

        with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
                VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
            src_length = len(src_reader)
            tgt_length = len(tgt_reader)
            if src_length < self.source_frames:
                raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
            if tgt_length < self.target_frames:
                raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")

            src_indices = self._sample_indices(self.source_frames, src_length)
            tgt_indices = self._sample_indices(self.target_frames, tgt_length)

            try:
                src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
                tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
            except FunctionTimedOut as e:
                raise ValueError(f"Read {idx} timeout.") from e
            except Exception as e:
                raise ValueError(f"Failed to extract frames from pair. Error is {e}.") from e

        # Source and target may differ in resolution; align them spatially.
        src_frames, tgt_frames = self._crop_to_common_size(src_frames, tgt_frames)

        if self.enable_bucket:
            return src_frames, tgt_frames
        return self._frames_to_tensor(src_frames), self._frames_to_tensor(tgt_frames)

    def _load_image_pair(self, data_info):
        """Load ``original_image`` / ``edited_image`` as a pair of single-frame batches.

        Raises:
            ValueError: either image path is missing from the entry.
        """
        src_img_rel = data_info.get('original_image')
        tgt_img_rel = data_info.get('edited_image')
        if src_img_rel is None or tgt_img_rel is None:
            raise ValueError('Missing original_image/edited_image for image sample')

        src_img = self._open_rgb(self._resolve_path(src_img_rel))
        tgt_img = self._open_rgb(self._resolve_path(tgt_img_rel))

        if not self.enable_bucket:
            return self.image_transforms(src_img).unsqueeze(0), self.image_transforms(tgt_img).unsqueeze(0)
        return np.expand_dims(np.array(src_img), axis=0), np.expand_dims(np.array(tgt_img), axis=0)

    def get_batch(self, idx):
        """Load the pair at ``idx`` (modulo dataset length).

        Returns:
            ``(src, tgt, text, 'video')`` for video entries (``type`` missing or
            ``'video'``), otherwise ``(src, tgt, text, 'image')``. ``text`` is
            ``''`` with probability ``text_drop_ratio``.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'video')

        raw_text = data_info.get('text', '')
        if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
            # Never feed an empty instruction into the template.
            raw_text = "the content has been modified"

        if self.instruction_template and "{edit_instruction}" in self.instruction_template:
            text = self.instruction_template.format(edit_instruction=raw_text)
        else:
            text = raw_text

        if data_type == 'video':
            src_tensor, tgt_tensor = self._load_video_pair(idx, data_info)
            media_type = 'video'
        else:
            src_tensor, tgt_tensor = self._load_image_pair(data_info)
            media_type = 'image'

        if random.random() < self.text_drop_ratio:
            text = ''
        return src_tensor, tgt_tensor, text, media_type

    def __len__(self):
        """Number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return one sample dict, resampling a random index whenever loading fails.

        Retries that land on an entry whose type differs from the requested
        entry's type (default ``'video'``) are rejected, so the sample type
        stays stable for the requested index.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'video')
        while True:
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'video')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                # Separate name: never overwrite the expected ``data_type``.
                src_vals, tgt_vals, name, batch_type = self.get_batch(idx)
                sample = {}
                if batch_type == 'video':
                    sample["pixel_values_src_video"] = src_vals
                    sample["pixel_values_tgt_video"] = tgt_vals
                else:
                    sample["pixel_values_src_image"] = src_vals
                    sample["pixel_values_tgt_image"] = tgt_vals
                sample["text"] = name
                sample["data_type"] = batch_type
                sample["idx"] = idx
                return sample
            except Exception as e:
                # Best-effort loading: report the bad entry and try another one.
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)
|
|
|
|
|
|
|
|
class ImageVideoCoTDataset(Dataset):
    """Chain-of-Thought (CoT) style image/video editing dataset.

    * Video entries load three clips: ``original_video``, a grounding clip
      (``grounded_video``, falling back to ``ground_video``) and
      ``edited_video``; ``get_batch`` returns ``(src, ground, tgt, text, 'video')``.
    * Image entries load ``original_image`` and ``edited_image`` only (as in
      ImageVideoEditDataset); ``get_batch`` returns ``(src, None, tgt, text, 'image')``.

    The grounding clip can be synthesized by interpolation
    (``enable_gradual_ground``) and/or post-processed around gray regions by
    one of three mutually exclusive ``enable_gray_*`` modes.
    """

    def __init__(
        self,
        ann_path,
        data_root=None,
        video_sample_size=512,
        video_sample_stride=1,
        source_frames=33,
        reasoning_frames=4,
        target_frames=33,
        text_drop_ratio=0.1,
        enable_bucket=False,
        enable_inpaint=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instruction}, and finally the same scene but {edit_instruction}",
        enable_gradual_ground=False,
        enable_gray_red_mask=False,
        enable_gray_black_background=False,
        enable_gray_alpha_overlay=False,
        gray_alpha=0.5,
        gray_intensity_range=(96, 160),
        gray_tolerance=12,
    ):
        """Load annotations, build preprocessing pipelines and validate gray options.

        Args:
            ann_path: JSON annotation file (list, or dict normalized by
                ``_load_annotations``).
            data_root: Optional directory joined in front of every media path.
            video_sample_size: Output size; an int means a square ``(s, s)``.
            video_sample_stride: Frame stride used when sampling clip indices.
            source_frames: Frames sampled from the original video.
            reasoning_frames: Number of grounding frames read from the ground
                video (at the first source sampling indices) when
                ``enable_gradual_ground`` is off.
            target_frames: Frames sampled from the edited video.
            text_drop_ratio: Probability that the prompt is replaced by ``''``.
            enable_bucket: If True, return raw numpy arrays instead of tensors.
            enable_inpaint, video_length_drop_start, video_length_drop_end:
                Stored on the instance; not read by this class.
            instruction_template: Format string; when it contains
                ``{edit_instruction}`` it is formatted with ``edit_instruction``
                and ``ground_instruction``.
            enable_gradual_ground: Synthesize the grounding frames by
                interpolating from the ground video's first frame to the edited
                clip's first frame.
            enable_gray_red_mask: Paint detected gray regions of the grounding
                frames pure red.
            enable_gray_black_background: Keep only detected gray regions of
                the grounding frames; everything else becomes 0.
            enable_gray_alpha_overlay: Replace grounding frames by source frames
                with gray alpha-blended where the grounding frames are gray.
            gray_alpha: Blend weight for the alpha overlay, in [0, 1].
            gray_intensity_range: ``(min, max)`` allowed brightest-channel value
                for a pixel to count as gray.
            gray_tolerance: Maximum channel spread (max - min) for a gray pixel.

        Raises:
            ValueError: more than one gray mode is enabled, or a gray parameter
                is out of range / malformed.
        """
        self.data_root = data_root
        self.dataset = self._load_annotations(ann_path)
        self.length = len(self.dataset)

        self.video_sample_stride = video_sample_stride
        self.source_frames = source_frames
        self.reasoning_frames = reasoning_frames
        self.target_frames = target_frames
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        if isinstance(video_sample_size, int):
            self.video_sample_size = (video_sample_size, video_sample_size)
        else:
            self.video_sample_size = tuple(video_sample_size)
        # Video frames are already float tensors in [0, 1]; no ToTensor needed.
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        # Images are PIL images, hence ToTensor before normalization.
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.video_sample_size)),
            transforms.CenterCrop(self.video_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.instruction_template = instruction_template
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.enable_gradual_ground = enable_gradual_ground

        enabled_modes = int(bool(enable_gray_red_mask)) + int(bool(enable_gray_black_background)) + int(bool(enable_gray_alpha_overlay))
        if enabled_modes > 1:
            raise ValueError("enable_gray_red_mask, enable_gray_black_background and enable_gray_alpha_overlay cannot be enabled simultaneously.")
        self.enable_gray_red_mask = enable_gray_red_mask
        self.enable_gray_black_background = enable_gray_black_background
        self.enable_gray_alpha_overlay = enable_gray_alpha_overlay
        self.gray_alpha = float(gray_alpha)
        if not (0.0 <= self.gray_alpha <= 1.0):
            raise ValueError("gray_alpha must be in [0,1].")
        if not isinstance(gray_intensity_range, (list, tuple)) or len(gray_intensity_range) != 2:
            raise ValueError("gray_intensity_range must contain exactly two values (min and max intensity).")
        self.gray_intensity_range = (int(gray_intensity_range[0]), int(gray_intensity_range[1]))
        if self.gray_intensity_range[0] > self.gray_intensity_range[1]:
            raise ValueError("gray_intensity_range min value cannot be greater than max value.")
        self.gray_tolerance = int(gray_tolerance)

        # Only one sample size exists in this dataset; its shorter side is used.
        self.larger_side_of_image_and_video = min(self.video_sample_size)

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation JSON and normalize dict-form annotations.

        A JSON list is returned as-is. A JSON object ``{key: info}`` becomes a
        list of copies of its values, each gaining ``text`` (from
        ``edit_instruction``, else ``instruction``, else an existing ``text``;
        missing/``None`` -> ``""``) and ``file_path`` (``original_video`` for
        video entries -- ``type`` missing or ``'video'`` -- else
        ``original_image``; ``""`` if absent).
        """
        # Binary mode lets json auto-detect the encoding; ``with`` closes the
        # handle that the old bare ``open()`` leaked.
        with open(ann_path, "rb") as ann_file:
            dataset = json.load(ann_file)
        if not isinstance(dataset, dict):
            return dataset

        normalized = []
        for info in dataset.values():
            entry = dict(info)
            if "edit_instruction" in entry:
                entry["text"] = entry["edit_instruction"]
            elif "instruction" in entry:
                entry["text"] = entry["instruction"]
            elif "text" not in entry:
                entry["text"] = ""
            if entry["text"] is None:
                entry["text"] = ""
            if info.get("type", "video") == "video":
                entry["file_path"] = entry.get("original_video", "")
            else:
                entry["file_path"] = entry.get("original_image", "")
            normalized.append(entry)
        return normalized

    def _resolve_path(self, rel_path):
        """Join ``rel_path`` onto ``self.data_root`` when a root is configured."""
        if self.data_root is None:
            return rel_path
        return os.path.join(self.data_root, rel_path)

    @staticmethod
    def _open_rgb(path):
        """Open an image file, convert it to RGB, and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _sample_indices(self, num_frames, num_available):
        """Return ``num_frames`` evenly spaced integer frame indices starting at 0.

        The span is ``(num_frames - 1) * self.video_sample_stride`` frames,
        capped at ``num_available - 1``; when capped, spacing shrinks below the stride.
        """
        start_idx = 0
        end_idx = min(start_idx + (num_frames - 1) * self.video_sample_stride, num_available - 1)
        return np.linspace(start_idx, end_idx, num_frames, dtype=int)

    def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
        """Apply ``_resize_and_center_crop_image`` to every frame of an (F, H, W, ...) array."""
        return np.stack(
            [self._resize_and_center_crop_image(frame, target_h, target_w) for frame in frames_np],
            axis=0,
        )

    def _resize_and_center_crop_image(self, image_np, target_h, target_w):
        """Scale ``image_np`` so it covers ``(target_h, target_w)``, then center-crop to it."""
        h, w = image_np.shape[0], image_np.shape[1]
        # Larger factor so both sides reach at least the target size.
        scale = max(target_h / h, target_w / w)
        new_h = int(round(h * scale))
        new_w = int(round(w * scale))
        image_resized = cv2.resize(image_np, (new_w, new_h))
        y0 = max((new_h - target_h) // 2, 0)
        x0 = max((new_w - target_w) // 2, 0)
        return image_resized[y0:y0 + target_h, x0:x0 + target_w]

    def _crop_to_common_size(self, *batches):
        """Bring every (F, H, W, C) batch to the smallest H and W found among them.

        Batches already at that size are returned unchanged (same object).
        """
        target_h = min(batch.shape[1] for batch in batches)
        target_w = min(batch.shape[2] for batch in batches)
        return [
            batch if batch.shape[1:3] == (target_h, target_w)
            else self._resize_and_center_crop_batch(batch, target_h, target_w)
            for batch in batches
        ]

    def _frames_to_tensor(self, frames):
        """Convert (F, H, W, C) 0-255 frames to an (F, C, H, W) tensor and apply video_transforms."""
        tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous() / 255.
        return self.video_transforms(tensor)

    def _format_prompt(self, raw_text, ground_instruction):
        """Fill ``instruction_template`` if it contains ``{edit_instruction}``, else return ``raw_text``."""
        if self.instruction_template and "{edit_instruction}" in self.instruction_template:
            return self.instruction_template.format(
                edit_instruction=raw_text,
                ground_instruction=ground_instruction,
            )
        return raw_text

    def _derive_ground_instruction(self, edit_instruction_text: str) -> str:
        """Derive the grounded-object phrase via the module-level ``derive_ground_object_from_instruction``."""
        return derive_ground_object_from_instruction(edit_instruction_text)

    def _ensure_same_size_pair(self, img_a: np.ndarray, img_b: np.ndarray) -> tuple:
        """Return ``(img_a, img_b)``, bilinearly resizing ``img_b`` to ``img_a``'s H x W if they differ."""
        ha, wa = img_a.shape[:2]
        hb, wb = img_b.shape[:2]
        if (ha, wa) == (hb, wb):
            return img_a, img_b
        return img_a, cv2.resize(img_b, (wa, ha), interpolation=cv2.INTER_LINEAR)

    def _interpolate_ground_frames(self, ground_first: np.ndarray, target_first: np.ndarray,
                                   total_steps: int = 16,
                                   pick_indices: tuple = (0, 4, 8, 12)) -> np.ndarray:
        """Synthesize grounding frames by cross-fading between two (H, W, 3) 0-255 frames.

        The two frames are linearly interpolated over ``total_steps`` steps
        (step 0 == ``ground_first``, last step == ``target_first``, after
        resizing ``target_first`` to ``ground_first``'s size), and the steps in
        ``pick_indices`` (clamped to ``[0, total_steps - 1]``) are returned.

        Returns:
            uint8 array of shape ``(len(pick_indices), H, W, 3)``.
        """
        a_np, b_np = self._ensure_same_size_pair(ground_first, target_first)

        a_t = torch.from_numpy(a_np).float().div(255.0).permute(2, 0, 1).contiguous()
        b_t = torch.from_numpy(b_np).float().div(255.0).permute(2, 0, 1).contiguous()
        c, h, w = a_t.shape
        t_steps = int(total_steps)

        # Each channel/pixel is a 2-sample sequence [a, b]; resample it linearly
        # to t_steps samples. align_corners=True pins the endpoints to a and b.
        seq = torch.stack([a_t, b_t], dim=-1).reshape(1, c * h * w, 2)
        with torch.no_grad():
            seq_interp = F.interpolate(seq, size=t_steps, mode="linear", align_corners=True)
        frames = seq_interp.reshape(c, h, w, t_steps).permute(3, 0, 1, 2)

        out_frames = []
        for pick in pick_indices:
            safe_idx = max(0, min(int(pick), t_steps - 1))
            # Round (not truncate) so e.g. 149.99998 maps back to 150 and the
            # endpoint frames reproduce their inputs exactly.
            img = (frames[safe_idx].clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
            out_frames.append(img.permute(1, 2, 0).cpu().numpy())
        return np.stack(out_frames, axis=0)

    def _build_gray_mask(self, frame: np.ndarray) -> np.ndarray:
        """Return an (H, W) boolean mask of gray pixels in an (H, W, C) frame.

        A pixel is gray when its channel spread (max - min) is at most
        ``gray_tolerance`` and its brightest channel lies within
        ``gray_intensity_range`` (inclusive). Frames whose maximum is <= 1.0
        are treated as [0, 1]-scaled and multiplied by 255 first.
        """
        frame_float = frame.astype(np.float32)
        if frame_float.max() <= 1.0:
            frame_float = frame_float * 255.0
        channel_max = frame_float.max(axis=2)
        channel_min = frame_float.min(axis=2)
        min_intensity, max_intensity = self.gray_intensity_range
        mask = (channel_max - channel_min) <= float(self.gray_tolerance)
        mask &= channel_max >= float(min_intensity)
        mask &= channel_max <= float(max_intensity)
        return mask

    def _apply_gray_region_effect(self, frames_np: np.ndarray, mode: str) -> np.ndarray:
        """Highlight gray regions in each frame of ``frames_np``.

        ``mode == "red"`` paints gray pixels red (``[1, 0, 0]`` for float frames
        with max <= 1, else ``[255, 0, 0]``); any other mode zeroes every
        non-gray pixel. Frames without gray pixels are passed through unchanged.
        """
        processed_frames = []
        for frame in frames_np:
            mask = self._build_gray_mask(frame)
            if not np.any(mask):
                processed_frames.append(frame)
                continue
            frame_out = frame.copy()
            if mode == "red":
                if np.issubdtype(frame_out.dtype, np.floating) and frame_out.max() <= 1.0:
                    red_value = np.array([1.0, 0.0, 0.0], dtype=frame_out.dtype)
                else:
                    red_value = np.array([255, 0, 0], dtype=frame_out.dtype)
                frame_out[mask] = red_value
            else:
                frame_out[:] = 0
                frame_out[mask] = frame[mask]
            processed_frames.append(frame_out)
        return np.stack(processed_frames, axis=0)

    def _apply_gray_overlay_from_reference(self, src_frames_np: np.ndarray, ref_frames_np: np.ndarray,
                                           alpha: float = 0.5, gray_value: float = 0.5, num_frames: int = 4) -> np.ndarray:
        """Alpha-blend gray into ``src`` wherever the matching ``ref`` frame is gray.

        For each of the first ``num_frames`` frames (bounded by both inputs),
        gray regions of the reference frame (``_build_gray_mask``) are located
        and, at those positions only, the source pixel becomes
        ``(1 - alpha) * src + alpha * gray_value`` computed on a [0, 1] scale
        and written back on the frame's own scale. Other pixels and later
        frames are left untouched.

        Args:
            src_frames_np: Source frames ``(F, H, W, C)``; integer 0-255, or
                float in either [0, 1] or [0, 255] (detected per frame).
            ref_frames_np: Frames used only to locate gray regions.
            alpha: Blend weight of the gray color, clamped to [0, 1].
            gray_value: Gray level on a [0, 1] scale, clamped to [0, 1].
            num_frames: Maximum number of leading frames to process.

        Returns:
            A copy of ``src_frames_np`` (same shape/dtype) with the overlay
            applied, or ``src_frames_np`` itself when nothing is processed.
        """
        n = min(int(num_frames), int(src_frames_np.shape[0]), int(ref_frames_np.shape[0]))
        if n <= 0:
            return src_frames_np
        out = src_frames_np.copy()
        a = 0.0 if alpha < 0.0 else (1.0 if alpha > 1.0 else float(alpha))
        gv = 0.0 if gray_value < 0.0 else (1.0 if gray_value > 1.0 else float(gray_value))
        is_float = np.issubdtype(out.dtype, np.floating)
        for i in range(n):
            mask = self._build_gray_mask(ref_frames_np[i])
            if not np.any(mask):
                continue
            frame = out[i]  # view into `out`: masked writes land in the result
            # Integer frames are 0-255; float frames may be [0, 1] or [0, 255].
            scale = 255.0 if (not is_float or frame.max() > 1.0) else 1.0
            pixels = frame[mask].astype(np.float32) / scale
            if scale != 1.0:
                pixels = np.clip(pixels, 0.0, 1.0)
            blended = ((1.0 - a) * pixels + a * gv) * scale
            # Only masked pixels are written, so unmasked pixels keep their exact
            # values (previously the whole frame round-tripped through float32
            # and float [0, 255] frames came back on a [0, 1] scale).
            if is_float:
                frame[mask] = blended.astype(out.dtype)
            else:
                frame[mask] = np.clip(np.rint(blended), 0, 255).astype(out.dtype)
        return out

    def _load_video_triplet(self, idx, data_info):
        """Read source, grounding and edited clips for ``data_info``.

        Returns:
            ``(src, ground, tgt)`` numpy arrays when ``enable_bucket`` is set,
            otherwise transformed tensors.

        Raises:
            ValueError: the grounding path is missing, a video is too short,
                reading timed out, or frame extraction failed.
        """
        src_rel = data_info['original_video']
        ground_rel = data_info.get('grounded_video', data_info.get('ground_video'))
        tgt_rel = data_info['edited_video']
        if ground_rel is None:
            # Fail with a clear message instead of an opaque path-join TypeError.
            raise ValueError('Missing grounded_video/ground_video for video sample')

        src_path = self._resolve_path(src_rel)
        ground_path = self._resolve_path(ground_rel)
        tgt_path = self._resolve_path(tgt_rel)

        from decord import cpu

        with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
                VideoReader_contextmanager(ground_path, num_threads=2, ctx=cpu(0)) as ground_reader, \
                VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
            src_length = len(src_reader)
            ground_length = len(ground_reader)
            tgt_length = len(tgt_reader)
            if src_length < self.source_frames:
                raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
            if tgt_length < self.target_frames:
                raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")

            src_indices = self._sample_indices(self.source_frames, src_length)
            tgt_indices = self._sample_indices(self.target_frames, tgt_length)

            try:
                src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
                tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))

                if self.enable_gradual_ground:
                    ground_first = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, [0]))
                    # NOTE(review): five fixed pick indices are used here regardless
                    # of self.reasoning_frames -- confirm this is intended.
                    ground_frames = self._interpolate_ground_frames(
                        ground_first=ground_first[0],
                        target_first=tgt_frames[0],
                        total_steps=16,
                        pick_indices=(0, 3, 6, 9, 12),
                    )
                else:
                    # Grounding frames are read at the same positions as the
                    # first `reasoning_frames` source frames, keeping them aligned.
                    ground_indices = src_indices[:self.reasoning_frames]
                    if len(ground_indices) > 0 and ground_indices[-1] >= ground_length:
                        raise ValueError(
                            f"Data inconsistency error: Ground video has only {ground_length} frames, "
                            f"but the source-based sampling (stride={self.video_sample_stride}) "
                            f"requires reading up to frame {ground_indices[-1]}. "
                            f"File: {ground_path}"
                        )
                    ground_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, ground_indices))
            except FunctionTimedOut as e:
                raise ValueError(f"Read {idx} timeout.") from e
            except Exception as e:
                raise ValueError(f"Failed to extract frames from triplet. Error is {e}.") from e

        # The three clips may differ in resolution; align them spatially.
        src_frames, ground_frames, tgt_frames = self._crop_to_common_size(src_frames, ground_frames, tgt_frames)

        if self.enable_gray_red_mask or self.enable_gray_black_background:
            effect_mode = "red" if self.enable_gray_red_mask else "black"
            ground_frames = self._apply_gray_region_effect(ground_frames, effect_mode)
        elif self.enable_gray_alpha_overlay:
            # NOTE(review): this returns a copy of *all* source frames (only the
            # first 4 overlaid), so the grounding clip then has as many frames as
            # the source clip rather than reasoning_frames -- confirm intended.
            ground_frames = self._apply_gray_overlay_from_reference(
                src_frames, ground_frames, alpha=self.gray_alpha, gray_value=0.5, num_frames=4
            )

        if self.enable_bucket:
            return src_frames, ground_frames, tgt_frames
        return (
            self._frames_to_tensor(src_frames),
            self._frames_to_tensor(ground_frames),
            self._frames_to_tensor(tgt_frames),
        )

    def _load_image_pair(self, data_info):
        """Load ``original_image`` / ``edited_image`` as a pair of single-frame batches.

        Raises:
            ValueError: either image path is missing from the entry.
        """
        src_img_rel = data_info.get('original_image')
        tgt_img_rel = data_info.get('edited_image')
        if src_img_rel is None or tgt_img_rel is None:
            raise ValueError('Missing original_image/edited_image for image sample')

        src_img = self._open_rgb(self._resolve_path(src_img_rel))
        tgt_img = self._open_rgb(self._resolve_path(tgt_img_rel))

        if not self.enable_bucket:
            return self.image_transforms(src_img).unsqueeze(0), self.image_transforms(tgt_img).unsqueeze(0)
        return np.expand_dims(np.array(src_img), axis=0), np.expand_dims(np.array(tgt_img), axis=0)

    def get_batch(self, idx):
        """Load the entry at ``idx`` (modulo dataset length).

        Returns:
            ``(src, ground, tgt, text, 'video')`` for video entries (``type``
            missing or ``'video'``), else ``(src, None, tgt, text, 'image')``.
            ``text`` is ``''`` with probability ``text_drop_ratio``.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'video')

        raw_text = data_info.get('text', '')
        if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
            # Never feed an empty instruction into the template.
            raw_text = "the content has been modified"

        if data_type == 'video':
            src_tensor, ground_tensor, tgt_tensor = self._load_video_triplet(idx, data_info)
            text = self._format_prompt(raw_text, self._derive_ground_instruction(raw_text))
            media_type = 'video'
        else:
            src_tensor, tgt_tensor = self._load_image_pair(data_info)
            ground_tensor = None
            text = self._format_prompt(raw_text, "")
            media_type = 'image'

        if random.random() < self.text_drop_ratio:
            text = ''
        return src_tensor, ground_tensor, tgt_tensor, text, media_type

    def __len__(self):
        """Number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return one sample dict, resampling a random index whenever loading fails.

        Retries landing on an entry whose type differs from the requested
        entry's type (default ``'video'``) are rejected, keeping the sample
        type stable for the requested index.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'video')
        while True:
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'video')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                # Separate name: never overwrite the expected ``data_type``.
                src_vals, ground_vals, tgt_vals, name, batch_type = self.get_batch(idx)
                sample = {}
                if batch_type == 'video':
                    sample["pixel_values_src_video"] = src_vals
                    sample["pixel_values_ground_video"] = ground_vals
                    sample["pixel_values_tgt_video"] = tgt_vals
                else:
                    sample["pixel_values_src_image"] = src_vals
                    sample["pixel_values_tgt_image"] = tgt_vals
                sample["text"] = name
                sample["data_type"] = batch_type
                sample["idx"] = idx
                return sample
            except Exception as e:
                # Best-effort loading: report the bad entry and try another one.
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)
|
|
|
|
|
def padding_image(images, new_width, new_height):
    """Letterbox a PIL image onto a white RGB canvas of the requested size.

    The image is resized, preserving its aspect ratio, to the largest size that
    fits inside ``new_width x new_height``. It is then pasted centered on a
    white canvas.

    Args:
        images: Source PIL image (a single image despite the plural name).
        new_width: Canvas width in pixels.
        new_height: Canvas height in pixels.

    Returns:
        A new ``(new_width, new_height)`` RGB PIL image.
    """
    new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))

    aspect_ratio = images.width / images.height
    # The fit rule is the same whatever the canvas orientation (the old code
    # had two byte-identical branches): whichever of image/canvas is relatively
    # wider decides whether the image spans the full width or the full height.
    if aspect_ratio > new_width / new_height:
        new_img_width = new_width
        new_img_height = int(new_img_width / aspect_ratio)
    else:
        new_img_height = new_height
        new_img_width = int(new_img_height * aspect_ratio)
    # An extremely thin image can truncate to 0 px, which PIL's resize rejects.
    new_img_width = max(new_img_width, 1)
    new_img_height = max(new_img_height, 1)

    resized_img = images.resize((new_img_width, new_img_height))

    paste_x = (new_width - new_img_width) // 2
    paste_y = (new_height - new_img_height) // 2
    new_image.paste(resized_img, (paste_x, paste_y))

    return new_image
|
|
|
|
|
class ImageVideoControlDataset(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
ann_path, data_root=None, |
|
|
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, |
|
|
image_sample_size=512, |
|
|
video_repeat=0, |
|
|
text_drop_ratio=0.1, |
|
|
enable_bucket=False, |
|
|
video_length_drop_start=0.1, |
|
|
video_length_drop_end=0.9, |
|
|
enable_inpaint=False, |
|
|
enable_camera_info=False, |
|
|
): |
|
|
|
|
|
if ann_path.endswith('.csv'): |
|
|
with open(ann_path, 'r') as csvfile: |
|
|
dataset = list(csv.DictReader(csvfile)) |
|
|
elif ann_path.endswith('.json'): |
|
|
dataset = json.load(open(ann_path)) |
|
|
|
|
|
self.data_root = data_root |
|
|
|
|
|
|
|
|
if video_repeat > 0: |
|
|
self.dataset = [] |
|
|
for data in dataset: |
|
|
if data.get('type', 'image') != 'video': |
|
|
self.dataset.append(data) |
|
|
|
|
|
for _ in range(video_repeat): |
|
|
for data in dataset: |
|
|
if data.get('type', 'image') == 'video': |
|
|
self.dataset.append(data) |
|
|
else: |
|
|
self.dataset = dataset |
|
|
del dataset |
|
|
|
|
|
self.length = len(self.dataset) |
|
|
print(f"data scale: {self.length}") |
|
|
|
|
|
self.enable_bucket = enable_bucket |
|
|
self.text_drop_ratio = text_drop_ratio |
|
|
self.enable_inpaint = enable_inpaint |
|
|
self.enable_camera_info = enable_camera_info |
|
|
|
|
|
self.video_length_drop_start = video_length_drop_start |
|
|
self.video_length_drop_end = video_length_drop_end |
|
|
|
|
|
|
|
|
self.video_sample_stride = video_sample_stride |
|
|
self.video_sample_n_frames = video_sample_n_frames |
|
|
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) |
|
|
self.video_transforms = transforms.Compose( |
|
|
[ |
|
|
transforms.Resize(min(self.video_sample_size)), |
|
|
transforms.CenterCrop(self.video_sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
] |
|
|
) |
|
|
if self.enable_camera_info: |
|
|
self.video_transforms_camera = transforms.Compose( |
|
|
[ |
|
|
transforms.Resize(min(self.video_sample_size)), |
|
|
transforms.CenterCrop(self.video_sample_size) |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) |
|
|
self.image_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(self.image_sample_size)), |
|
|
transforms.CenterCrop(self.image_sample_size), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) |
|
|
]) |
|
|
|
|
|
self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) |
|
|
|
|
|
def _read_resized_frames(self, reader, batch_index, idx):
    """Read the frames at ``batch_index`` from ``reader`` and resize each one.

    Frames are fetched with ``get_video_reader_batch`` under a
    ``VIDEO_READER_TIMEOUT``-second limit and resized with
    ``resize_frame(frame, self.larger_side_of_image_and_video)``.

    Returns:
        A numpy array stacking the resized frames.

    Raises:
        ValueError: if reading times out (message names ``idx``) or if any
            other error occurs while reading or resizing.
    """
    try:
        frames = func_timeout(
            VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(reader, batch_index)
        )
        resized_frames = [
            resize_frame(frame, self.larger_side_of_image_and_video) for frame in frames
        ]
        return np.array(resized_frames)
    except FunctionTimedOut as e:
        raise ValueError(f"Read {idx} timeout.") from e
    except Exception as e:
        raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e


def get_batch(self, idx):
    """Load one sample and its control signal.

    The record is ``self.dataset[idx % len(self.dataset)]``. It must contain
    ``file_path``, ``text`` and ``control_file_path``, and may contain
    ``type`` (default ``'image'``). When ``self.data_root`` is set, paths are
    joined onto it.

    Video records:
        - Up to ``self.video_sample_n_frames`` frames are sampled with stride
          ``self.video_sample_stride``. The clip ends before frame
          ``int(video_length_drop_end * total_frames)``.
        - The same frame indices are read from the control video.
        - With ``self.enable_camera_info``, the control video is not read.
          The control pixels are zeros, and a ``.txt`` control path is loaded
          as a pose file to produce ``control_camera_values``.

    Image records:
        The image and its control image are loaded as RGB.

    Output format depends on bucketing. Without it, outputs are torch
    tensors run through ``self.video_transforms`` / ``self.image_transforms``.
    With it, outputs are numpy arrays with a leading frame axis.

    Returns:
        ``(pixel_values, control_pixel_values, control_camera_values, text,
        data_type)``. ``data_type`` is ``"video"`` or ``'image'``. ``text`` is
        replaced by ``''`` with probability ``self.text_drop_ratio``.

    Raises:
        ValueError: if the video yields no frames, or frame extraction fails
            or times out.
        KeyError: if a required record field is missing.
    """
    data_info = self.dataset[idx % len(self.dataset)]
    file_path, text = data_info['file_path'], data_info['text']

    if data_info.get('type', 'image') == 'video':
        video_dir = file_path if self.data_root is None else os.path.join(self.data_root, file_path)

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
            # Keep the frame count: ``video_reader`` may be deleted below, but
            # the pose interpolation after this block still needs it.
            num_video_frames = len(video_reader)

            min_sample_n_frames = min(
                self.video_sample_n_frames,
                int(num_video_frames * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
            )
            if min_sample_n_frames == 0:
                raise ValueError(f"No Frames in video.")

            # The sampled clip's last index stays below ``video_length``.
            video_length = int(self.video_length_drop_end * num_video_frames)
            clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
            start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

            pixel_values = self._read_resized_frames(video_reader, batch_index, idx)

            if not self.enable_bucket:
                # (F, H, W, C) uint8 -> (F, C, H, W) in [0, 1], then resize/crop/normalise.
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                pixel_values = self.video_transforms(pixel_values)
                del video_reader

            # Randomly drop the caption.
            if random.random() < self.text_drop_ratio:
                text = ''

        control_video_id = data_info['control_file_path']
        if self.data_root is not None:
            control_video_id = os.path.join(self.data_root, control_video_id)

        if self.enable_camera_info:
            if control_video_id.lower().endswith('.txt'):
                if not self.enable_bucket:
                    control_pixel_values = torch.zeros_like(pixel_values)

                    control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
                    control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
                    # NOTE(review): on this 4-D tensor F.interpolate resizes the
                    # last two axes, so axis 2 is set to the video's frame count.
                    # Unlike the bucket branch, no batch_index subsampling is
                    # applied. This looks unintended; confirm against
                    # process_pose_file's output layout.
                    control_camera_values = F.interpolate(control_camera_values, size=(num_video_frames, control_camera_values.size(3)), mode='bilinear', align_corners=True)
                    control_camera_values = self.video_transforms_camera(control_camera_values)
                else:
                    control_pixel_values = np.zeros_like(pixel_values)

                    control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
                    control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
                    # Resample the pose rows to one per decoded video frame,
                    # then keep only the rows of the sampled clip.
                    control_camera_values = F.interpolate(control_camera_values, size=(num_video_frames, control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
                    control_camera_values = np.array([control_camera_values[index] for index in batch_index])
            else:
                control_pixel_values = np.zeros_like(pixel_values) if self.enable_bucket else torch.zeros_like(pixel_values)
                control_camera_values = None
        else:
            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                # Read the same frame indices as the main clip so they stay aligned.
                control_pixel_values = self._read_resized_frames(control_video_reader, batch_index, idx)

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    control_pixel_values = self.video_transforms(control_pixel_values)
                    del control_video_reader
            control_camera_values = None

        return pixel_values, control_pixel_values, control_camera_values, text, "video"

    image_path = file_path
    if self.data_root is not None:
        image_path = os.path.join(self.data_root, image_path)
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    if not self.enable_bucket:
        image = self.image_transforms(image).unsqueeze(0)
    else:
        image = np.expand_dims(np.array(image), 0)

    # Randomly drop the caption.
    if random.random() < self.text_drop_ratio:
        text = ''

    control_image_id = data_info['control_file_path']
    # Resolve against data_root, like every other path in this method.
    if self.data_root is not None:
        control_image_id = os.path.join(self.data_root, control_image_id)

    with Image.open(control_image_id) as raw_control_image:
        control_image = raw_control_image.convert('RGB')
    if not self.enable_bucket:
        control_image = self.image_transforms(control_image).unsqueeze(0)
    else:
        control_image = np.expand_dims(np.array(control_image), 0)

    return image, control_image, None, text, 'image'
|
|
def __len__(self):
    """Report the dataset size that was cached in ``self.length`` at construction."""
    n_records = self.length
    return n_records
|
|
|
|
|
def __getitem__(self, idx):
    """Return a sample dict for ``idx``, retrying on failure.

    The media type of the first requested record is fixed up front. Whenever
    ``get_batch`` raises, or the record at the current index has a different
    type, the error is printed and a new random index in
    ``[0, self.length - 1]`` is tried. Note that this loop never gives up.

    When inpainting is enabled and bucketing is disabled, the sample also
    gets a random ``mask``, the masked pixels, and ``clip_pixel_values``
    (the first frame as HWC, mapped back from [-1, 1] to [0, 255]).
    """
    expected_type = self.dataset[idx % len(self.dataset)].get('type', 'image')

    while True:
        sample = {}
        try:
            candidate = self.dataset[idx % len(self.dataset)]
            if candidate.get('type', 'image') != expected_type:
                raise ValueError("data_type_local != data_type")

            (pixel_values, control_pixel_values, control_camera_values,
             caption, expected_type) = self.get_batch(idx)

            sample.update(
                pixel_values=pixel_values,
                control_pixel_values=control_pixel_values,
                text=caption,
                data_type=expected_type,
                idx=idx,
            )
            if self.enable_camera_info:
                sample["control_camera_values"] = control_camera_values

            if sample:
                break
        except Exception as e:
            print(e, self.dataset[idx % len(self.dataset)])
            idx = random.randint(0, self.length - 1)

    if self.enable_inpaint and not self.enable_bucket:
        mask = get_random_mask(pixel_values.size())
        sample["mask_pixel_values"] = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
        sample["mask"] = mask

        first_frame = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
        sample["clip_pixel_values"] = (first_frame * 0.5 + 0.5) * 255

    return sample