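# NOTE: the text "Spaces: Runtime error" appears to be hosting-page status
# output captured together with this file; it is not part of the module.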
from collections.abc import Sequence

import numpy as np
import torch
class Collect(object):
    """Split a pipeline ``results`` dict into model inputs and metadata.

    ``keys`` names entries of ``results`` whose values are collections of
    data-key names (for example a "fields" set). Every name found that way is
    copied into the output, restricted to the sequence keys listed in
    ``results["sequence_fields"]``. Everything else in ``results`` is placed
    under ``"img_metas"``.

    Args:
        keys: Names of the ``results`` entries that list the data keys to
            collect. Entries missing from ``results`` are skipped.
        meta_keys: Stored on the instance and shown by ``__repr__``.
            ``__call__`` does not consult it: all non-data entries of
            ``results`` are forwarded as metadata regardless of this value.
    """

    def __init__(
        self,
        keys,
        meta_keys=(
            "filename",
            "keyframe_idx",
            "sequence_name",
            "image_filename",
            "depth_filename",
            "image_ori_shape",
            "camera",
            "original_camera",
            "sfm",
            "image_shape",
            "resized_shape",
            "scale_factor",
            "rotation",
            "resize_factor",
            "flip",
            "flip_direction",
            "dataset_name",
            "paddings",
            "max_value",
            "log_mean",
            "log_std",
            "image_rescale",
            "focal_rescale",
            "depth_rescale",
        ),
    ):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Return ``{data_key: {sequence_key: value}, ..., "img_metas": {...}}``.

        Raises:
            KeyError: If at least one data key is selected and ``results``
                lacks ``"sequence_fields"``, the data key itself, or one of
                the listed sequence keys under it.
        """
        # Flatten the name collections referenced by ``self.keys``.
        selected = []
        for field in self.keys:
            selected.extend(results.get(field, []))

        collected = {}
        for name in selected:
            per_sequence = {}
            for sequence_key in results["sequence_fields"]:
                per_sequence[sequence_key] = results[name][sequence_key]
            collected[name] = per_sequence

        # Whatever was not selected as data travels along as metadata.
        metas = {}
        for name, value in results.items():
            if name in selected:
                continue
            metas[name] = value
        collected["img_metas"] = metas
        return collected

    def __repr__(self):
        name = self.__class__.__name__
        return name + f"(keys={self.keys}, meta_keys={self.meta_keys})"
class AnnotationMask(object):
    """Create or refine validity masks for the ground-truth fields.

    For each ``key`` in ``results["gt_fields"]`` and each ``sequence_idx`` in
    ``results["sequence_fields"]`` the transform updates
    ``results[key + "_mask"][sequence_idx]``:

    * If ``key`` contains ``"flow"`` and ``key + "_mask"`` is already listed
      in ``results["mask_fields"]``, the existing mask is AND-ed with an
      in-bounds test (values within ``[-1, 1]``, combining the first channel
      with the remaining ones) and nothing else is done for that key.
    * Otherwise a range mask ``min_value < v (< max_value)`` is built, where
      ``v`` is the value itself for single-channel data or its norm over
      dim 1 for multi-channel data. The mask is passed through
      ``custom_fn``, AND-ed with any mask already stored for that entry, and
      ``key + "_mask"`` is registered in ``results["mask_fields"]``.

    Values are assumed to be torch tensors whose dim 1 is the channel
    dimension (the code relies on ``.shape[1]``, ``.norm(dim=1)`` and
    ``.bool()``) -- TODO confirm against the loaders feeding this transform.

    Args:
        min_value: Exclusive lower bound for valid values.
        max_value: Exclusive upper bound, or ``None`` to disable it.
        custom_fn: Callable invoked as ``custom_fn(mask, info=results)`` that
            returns the (possibly modified) mask. The default is an identity
            that accepts the ``info`` keyword; the previous default,
            ``lambda x: x``, raised ``TypeError`` when called that way.
    """

    def __init__(self, min_value, max_value, custom_fn=lambda x, info=None: x):
        self.min_value = min_value
        self.max_value = max_value
        self.custom_fn = custom_fn

    def __call__(self, results):
        """Update the masks in ``results`` in place and return ``results``.

        Raises:
            KeyError: If ``results`` has ground-truth fields but no
                ``"mask_fields"`` entry.
        """
        for key in results.get("gt_fields", []):
            mask_key = key + "_mask"
            sequence_fields = results.get("sequence_fields", [])

            # NOTE(review): the original indentation was lost; `continue` is
            # reconstructed as belonging to the "flow" branch, so non-flow
            # keys with a pre-existing mask still get range-masked below.
            # Confirm against the upstream file.
            if mask_key in results["mask_fields"] and "flow" in key:
                for sequence_idx in sequence_fields:
                    values = results[key][sequence_idx]
                    in_bounds = (values >= -1) & (values <= 1)
                    # A location is valid only if every channel is in bounds.
                    in_bounds = in_bounds[:, :1] & in_bounds[:, 1:]
                    results[mask_key][sequence_idx] = (
                        results[mask_key][sequence_idx].bool() & in_bounds
                    )
                continue

            for sequence_idx in sequence_fields:
                values = results[key][sequence_idx]
                # Multi-channel data (e.g. xyz or flow) is thresholded on its
                # norm over the channel dim; single-channel data on the value.
                if values.shape[1] == 1:
                    magnitude = values
                else:
                    magnitude = values.norm(dim=1, keepdim=True)

                mask = magnitude > self.min_value
                if self.max_value is not None:
                    mask = mask & (magnitude < self.max_value)
                mask = self.custom_fn(mask, info=results)

                # Combine with a mask produced earlier in the pipeline, if any.
                masks = results.setdefault(mask_key, {})
                if sequence_idx in masks:
                    masks[sequence_idx] = masks[sequence_idx].bool() & mask.bool()
                else:
                    masks[sequence_idx] = mask.bool()
                results["mask_fields"].add(mask_key)
        return results

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(min_value={self.min_value}, max_value={self.max_value})"
        )