import json
import os

import h5py
import numpy as np
import torch

from unik3d.datasets.image_dataset import ImageDataset
from unik3d.datasets.pipelines import AnnotationMask, Collect, Compose, KittiCrop
from unik3d.datasets.sequence_dataset import SequenceDataset
from unik3d.datasets.utils import DatasetFromList

class KITTI(ImageDataset):
    CAM_INTRINSIC = {
        "2011_09_26": torch.tensor(
            [
                [7.215377e02, 0.000000e00, 6.095593e02, 4.485728e01],
                [0.000000e00, 7.215377e02, 1.728540e02, 2.163791e-01],
                [0.000000e00, 0.000000e00, 1.000000e00, 2.745884e-03],
            ]
        ),
        "2011_09_28": torch.tensor(
            [
                [7.070493e02, 0.000000e00, 6.040814e02, 4.575831e01],
                [0.000000e00, 7.070493e02, 1.805066e02, -3.454157e-01],
                [0.000000e00, 0.000000e00, 1.000000e00, 4.981016e-03],
            ]
        ),
        "2011_09_29": torch.tensor(
            [
                [7.183351e02, 0.000000e00, 6.003891e02, 4.450382e01],
                [0.000000e00, 7.183351e02, 1.815122e02, -5.951107e-01],
                [0.000000e00, 0.000000e00, 1.000000e00, 2.616315e-03],
            ]
        ),
        "2011_09_30": torch.tensor(
            [
                [7.070912e02, 0.000000e00, 6.018873e02, 4.688783e01],
                [0.000000e00, 7.070912e02, 1.831104e02, 1.178601e-01],
                [0.000000e00, 0.000000e00, 1.000000e00, 6.203223e-03],
            ]
        ),
        "2011_10_03": torch.tensor(
            [
                [7.188560e02, 0.000000e00, 6.071928e02, 4.538225e01],
                [0.000000e00, 7.188560e02, 1.852157e02, -1.130887e-01],
                [0.000000e00, 0.000000e00, 1.000000e00, 3.779761e-03],
            ]
        ),
    }
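    # NOTE: the 3x4 matrices above appear to be the rectified projection
    # matrices (P_rect) shipped with the KITTI raw calibration files, keyed
    # by recording date; get_intrinsics() keeps only the left 3x3 K block.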
    min_depth = 0.05
    max_depth = 80.0
    depth_scale = 256.0
    log_mean = 2.5462
    log_std = 0.5871
    test_split = "kitti_eigen_test.txt"
    train_split = "kitti_eigen_train.txt"
    test_split_benchmark = "kitti_test.txt"
    hdf5_paths = ["kitti.hdf5"]
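    # Depth maps follow the usual KITTI encoding: uint16 PNGs storing
    # depth * 256, so metric depth is recovered as png_value / depth_scale.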

    def __init__(
        self,
        image_shape,
        split_file,
        test_mode,
        crop=None,
        benchmark=False,
        augmentations_db={},
        normalize=True,
        resize_method="hard",
        mini=1.0,
        **kwargs,
    ):
        super().__init__(
            image_shape=image_shape,
            split_file=split_file,
            test_mode=test_mode,
            benchmark=benchmark,
            normalize=normalize,
            augmentations_db=augmentations_db,
            resize_method=resize_method,
            mini=mini,
            **kwargs,
        )
        self.masker = AnnotationMask(
            min_value=0.0,
            max_value=self.max_depth if test_mode else None,
            custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
        )
        self.test_mode = test_mode
        self.crop = crop
        self.cropper_base = KittiCrop(crop_size=(352, 1216))
        self.load_dataset()

    def load_dataset(self):
        h5file = h5py.File(
            os.path.join(self.data_root, self.hdf5_paths[0]),
            "r",
            libver="latest",
            swmr=True,
        )
        txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop trailing newline
        h5file.close()
        dataset = []
        for line in txt_string.split("\n"):
            image_filename, depth_filename = line.strip().split(" ")[:2]
            if depth_filename == "None":
                continue
            sample = [image_filename, depth_filename]
            dataset.append(sample)
        if not self.test_mode:
            dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)

        self.dataset = DatasetFromList(dataset)
        self.log_load_dataset()
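    # Each split line is expected to hold at least two whitespace-separated
    # fields, e.g. (hypothetical) "<date>/<drive>/image_02/data/<frame>.png
    # <depth_path>"; samples whose depth path is the literal string "None"
    # carry no ground truth and are skipped above.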

    def get_intrinsics(self, idx, image_name):
        return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone()
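    # e.g. a (hypothetical) image name "2011_09_26/2011_09_26_drive_0002_sync/
    # image_02/data/0000000005.png" maps to the "2011_09_26" entry above.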

    def preprocess(self, results):
        results = self.replicate(results)
        for i, seq in enumerate(results["sequence_fields"]):
            self.resizer.ctx = None
            results[seq] = self.cropper_base(results[seq])
            results[seq] = self.resizer(results[seq])
            num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
            if num_pts < 50:
                raise IndexError(f"Too few points in depth map ({num_pts})")
            for key in results[seq].get("image_fields", ["image"]):
                results[seq][key] = results[seq][key].to(torch.float32) / 255
        # update fields common in sequence
        for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
            if key in results[(0, 0)]:
                results[key] = results[(0, 0)][key]
        results = self.pack_batch(results)
        return results

    def eval_mask(self, valid_mask, info={}):
        """Apply the garg_crop or eigen_crop evaluation region for testing."""
        mask_height, mask_width = valid_mask.shape[-2:]
        eval_mask = torch.zeros_like(valid_mask)
        if "garg" in self.crop:
            eval_mask[
                ...,
                int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
                int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
            ] = 1
        elif "eigen" in self.crop:
            eval_mask[
                ...,
                int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
                int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
            ] = 1
        return torch.logical_and(valid_mask, eval_mask)
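    # For reference: on a full 375x1242 KITTI frame these fractions reproduce
    # the usual Garg crop rows 153:371 / cols 44:1197 (0.40810811 * 375 = 153,
    # 0.99189189 * 375 = 371, 0.03594771 * 1242 = 44, 0.96405229 * 1242 = 1197
    # after int() truncation); here they scale with the mask's own size.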

    def get_mapper(self):
        return {
            "image_filename": 0,
            "depth_filename": 1,
        }

    def pre_pipeline(self, results):
        results = super().pre_pipeline(results)
        results["dense"] = [False] * self.num_copies
        results["quality"] = [1] * self.num_copies
        return results
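
# Usage sketch (hypothetical paths and arguments; the constructor signature is
# taken from above, but data_root and the split name must match your HDF5):
#
#   dataset = KITTI(
#       image_shape=(352, 1216),
#       split_file="kitti_eigen_test.txt",
#       test_mode=True,
#       crop="garg",
#       data_root="/data/kitti",  # assumed to be consumed by ImageDataset via **kwargs
#   )
#   sample = dataset[0]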

class KITTIBenchmark(ImageDataset):
    min_depth = 0.05
    max_depth = 80.0
    depth_scale = 256.0
    test_split = "test_split.txt"
    train_split = "val_split.txt"
    intrinsics_file = "intrinsics.json"
    hdf5_paths = ["kitti_benchmark.hdf5"]

    def __init__(
        self,
        image_shape,
        split_file,
        test_mode,
        crop=None,
        benchmark=False,
        augmentations_db={},
        normalize=True,
        resize_method="hard",
        mini=1.0,
        **kwargs,
    ):
        super().__init__(
            image_shape=image_shape,
            split_file=split_file,
            test_mode=test_mode,
            benchmark=True,
            normalize=normalize,
            augmentations_db=augmentations_db,
            resize_method=resize_method,
            mini=mini,
            **kwargs,
        )
        self.test_mode = test_mode
        self.crop = crop
        self.masker = AnnotationMask(
            min_value=self.min_depth,
            max_value=self.max_depth if test_mode else None,
            custom_fn=lambda x, *args, **kwargs: x,
        )
        self.collecter = Collect(keys=["image_fields", "mask_fields", "gt_fields"])
        self.load_dataset()

    def load_dataset(self):
        h5file = h5py.File(
            os.path.join(self.data_root, self.hdf5_paths[0]),
            "r",
            libver="latest",
            swmr=True,
        )
        txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop trailing newline
        intrinsics = np.array(h5file[self.intrinsics_file]).tobytes().decode("ascii")
        intrinsics = json.loads(intrinsics)
        h5file.close()
        dataset = []
        for line in txt_string.split("\n"):
            image_filename, depth_filename = line.strip().split(" ")
            # Look up K for this sample without overwriting the intrinsics
            # dict itself inside the loop.
            K = torch.tensor(
                intrinsics[os.path.join(*image_filename.split("/")[:2])]
            ).squeeze()[:, :3]
            sample = {
                "image_filename": image_filename,
                "depth_filename": depth_filename,
                "K": K,
            }
            dataset.append(sample)

        self.dataset = DatasetFromList(dataset)
        self.log_load_dataset()
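    # intrinsics.json is assumed to map "<date>/<drive>" prefixes (the first
    # two path components of each image filename) to 3x4 camera matrices, of
    # which only the left 3x3 K block is kept.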

class KITTIRMVD(SequenceDataset):
    min_depth = 0.05
    max_depth = 80.0
    depth_scale = 256.0
    default_fps = 10
    test_split = "test.txt"
    train_split = "test.txt"
    sequences_file = "sequences.json"
    hdf5_paths = ["kitti_rmvd.hdf5"]

    def __init__(
        self,
        image_shape,
        split_file,
        test_mode,
        crop=None,
        augmentations_db={},
        normalize=True,
        resize_method="hard",
        mini: float = 1.0,
        num_frames: int = 1,
        benchmark: bool = False,
        decode_fields: list[str] = ["image", "depth"],
        inplace_fields: list[str] = ["K", "cam2w"],
        **kwargs,
    ):
        super().__init__(
            image_shape=image_shape,
            split_file=split_file,
            test_mode=test_mode,
            benchmark=benchmark,
            normalize=normalize,
            augmentations_db=augmentations_db,
            resize_method=resize_method,
            mini=mini,
            num_frames=num_frames,
            decode_fields=decode_fields,
            inplace_fields=inplace_fields,
            **kwargs,
        )
        self.crop = crop
        self.resizer = Compose([KittiCrop(crop_size=(352, 1216)), self.resizer])
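        # Chain the fixed KITTI crop before the configured resizer so every
        # frame in a sequence is first cut to 352x1216 and then resized; this
        # mirrors what KITTI.preprocess does with cropper_base + resizer.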

    def eval_mask(self, valid_mask, info={}):
        """Apply the garg_crop or eigen_crop evaluation region for testing."""
        mask_height, mask_width = valid_mask.shape[-2:]
        eval_mask = torch.zeros_like(valid_mask)
        if "garg" in self.crop:
            eval_mask[
                ...,
                int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
                int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
            ] = 1
        elif "eigen" in self.crop:
            eval_mask[
                ...,
                int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
                int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
            ] = 1
        else:
            return valid_mask
        return torch.logical_and(valid_mask, eval_mask)
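
# Note that, unlike KITTI.eval_mask, the variant above falls back to the
# unmodified valid_mask when self.crop names neither "garg" nor "eigen".
# A quick (hypothetical) sanity check, with constructor args elided:
#
#   m = torch.ones(1, 375, 1242, dtype=torch.bool)
#   ds = KITTIRMVD(..., crop="garg")
#   assert ds.eval_mask(m).sum() < m.sum()  # the crop zeroes out the borders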