Spaces:
Runtime error
Runtime error
| from collections import defaultdict | |
| from functools import partial | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms.v2.functional as TF | |
| from PIL import Image | |
| from unik3d.utils.chamfer_distance import ChamferDistance | |
| from unik3d.utils.constants import DEPTH_BINS | |
# Single Chamfer-distance operator created at import time and shared by
# chamfer_dist, auc and f1_score below.
chamfer_cls: ChamferDistance = ChamferDistance()
def kl_div(gt, pred, eps: float = 1e-6, depth_bins=None):
    """Return KL(gt || pred) between the depth histograms of ``gt`` and ``pred``.

    Both inputs are bucketized against the same sorted bin edges and turned
    into normalized histograms with ``len(depth_bins) + 1`` buckets (values
    below the first edge / above the last edge get their own bucket).

    Args:
        gt: Ground-truth depths. Any shape; flattened before binning.
        pred: Predicted depths. Any shape; flattened before binning.
        eps: Added inside both logs so empty buckets do not produce ``-inf``.
        depth_bins: Optional 1-D sorted tensor of bin edges. Defaults to the
            module-level ``DEPTH_BINS`` constant.

    Returns:
        0-d tensor holding the KL divergence.
    """
    if depth_bins is None:
        depth_bins = DEPTH_BINS
    depth_bins = depth_bins.to(gt.device)
    n_buckets = len(depth_bins) + 1

    def _histogram(values):
        # torch.bincount only accepts 1-D input, so flatten before binning.
        idx = torch.bucketize(
            values.reshape(-1), boundaries=depth_bins, out_int32=True
        )
        counts = torch.bincount(idx, minlength=n_buckets)
        return counts / counts.sum()

    gt_hist = _histogram(gt)
    pred_hist = _histogram(pred)
    return torch.sum(
        gt_hist * (torch.log(gt_hist + eps) - torch.log(pred_hist + eps))
    )
def chamfer_dist(tensor1, tensor2):
    """Return the averaged two-way nearest-neighbour distance from ``chamfer_cls``.

    Args:
        tensor1: Batched point cloud; shape[0] is used as the batch size and
            shape[1] as the number of points.
        tensor2: Second batched point cloud with the same layout.

    Returns:
        ``(sqrt(dist1) + sqrt(dist2)) / 2`` computed elementwise. The sum is
        elementwise, so it presumably requires both clouds to have the same
        number of points (true for the masked gt/pred pairs in ``eval_3d``).
    """
    # One length entry per batch element; previously a single entry was
    # built, which only matched batch size 1.
    x_lengths = torch.full(
        (tensor1.shape[0],), tensor1.shape[1], dtype=torch.long, device=tensor1.device
    )
    y_lengths = torch.full(
        (tensor2.shape[0],), tensor2.shape[1], dtype=torch.long, device=tensor2.device
    )
    dist1, dist2, _, _ = chamfer_cls(
        tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
    )
    # NOTE(review): the sqrt implies chamfer_cls returns squared distances — confirm.
    return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2
def auc(tensor1, tensor2, thresholds):
    """Integrate precision over recall across a sweep of distance thresholds.

    Precision at a threshold is the fraction of ``dist1`` entries below it and
    recall is the fraction of ``dist2`` entries below it, where ``dist1``/
    ``dist2`` are the two directional outputs of ``chamfer_cls``.

    Args:
        tensor1: Batched point cloud (shape[0] batch, shape[1] points).
        tensor2: Batched point cloud with the same layout.
        thresholds: Iterable of distance thresholds; presumably increasing so
            recall is monotonic along the integration axis.

    Returns:
        0-d tensor: ``torch.trapz(precisions, recalls)``.
    """
    x_lengths = torch.full(
        (tensor1.shape[0],), tensor1.shape[1], dtype=torch.long, device=tensor1.device
    )
    y_lengths = torch.full(
        (tensor2.shape[0],), tensor2.shape[1], dtype=torch.long, device=tensor2.device
    )
    dist1, dist2, _, _ = chamfer_cls(
        tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
    )
    # NOTE(review): thresholds are compared to the raw chamfer_cls output,
    # while chamfer_dist takes its sqrt — confirm the intended units.
    precisions = torch.stack(
        [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
    )
    recalls = torch.stack(
        [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
    )
    return torch.trapz(precisions, recalls)
def delta(tensor1, tensor2, exponent):
    """Fraction of elements whose symmetric ratio is below ``1.25 ** exponent``.

    The symmetric ratio of two elements is ``max(t1 / t2, t2 / t1)``.
    Returns a float32 0-d tensor.
    """
    ratio = torch.maximum(tensor1 / tensor2, tensor2 / tensor1)
    bound = 1.25**exponent
    is_inlier = ratio < bound
    return is_inlier.to(torch.float32).mean()
def rho(tensor1, tensor2):
    """Normalized area under the angular-accuracy curve between two ray sets.

    Both inputs are normalized along the last dimension. The upper end of the
    threshold sweep depends on the largest polar angle (angle to the +z axis)
    of ``tensor1``: 15° if below 100°, otherwise 20°. The angular error
    between matching rays is swept against 100 thresholds from 0.5° up to that
    bound, and the accuracy curve is integrated and divided by the sweep width.

    Args:
        tensor1: Reference rays; last dimension holds the vector components
            (index 2 is used as z).
        tensor2: Predicted rays with the same shape.

    Returns:
        0-d tensor in [0, 1].
    """
    min_deg = 0.5
    acos_clip = 1 - 1e-6
    tensor1_norm = tensor1 / torch.norm(tensor1, dim=-1, p=2, keepdim=True).clip(
        min=1e-6
    )
    tensor2_norm = tensor2 / torch.norm(tensor2, dim=-1, p=2, keepdim=True).clip(
        min=1e-6
    )

    # Rounding can push a normalized component slightly past +-1; arccos would
    # then return NaN and (since NaN fails both comparisons) silently pick the
    # 30-degree branch.
    polar_cos = tensor1_norm[..., 2].clip(min=-1.0, max=1.0)
    max_polar_angle = torch.arccos(polar_cos).max() * 180.0 / torch.pi
    if max_polar_angle < 100.0:
        threshold = 15.0
    elif max_polar_angle < 190.0:
        threshold = 20.0
    else:
        # NOTE(review): arccos is bounded by 180 degrees, so for finite inputs
        # this branch is only reachable when tensor1 contains NaN.
        threshold = 30.0

    # Cosine of the angle between matching rays, kept strictly inside (-1, 1).
    cosine = (
        (tensor1_norm * tensor2_norm).sum(dim=-1).clip(min=-acos_clip, max=acos_clip)
    )
    # arccos is non-negative, so no abs() is needed on the error.
    angular_error = torch.arccos(cosine) * 180.0 / torch.pi

    thresholds = torch.linspace(min_deg, threshold, steps=100, device=tensor1.device)
    y_values = torch.stack(
        [(angular_error <= th).to(torch.float32).mean() for th in thresholds]
    )
    return torch.trapz(y_values, thresholds) / (threshold - min_deg)
def tau(tensor1, tensor2, perc):
    """Fraction of elements whose symmetric ratio is below ``1 + perc``.

    The symmetric ratio is ``max(t1 / t2, t2 / t1)``; returns a float32 0-d tensor.
    """
    symmetric_ratio = torch.maximum(tensor1 / tensor2, tensor2 / tensor1)
    within_tolerance = symmetric_ratio < (1.0 + perc)
    return within_tolerance.to(torch.float32).mean()
def ssi(tensor1, tensor2, qtl=0.05):
    """Align ``tensor2`` to ``tensor1`` with a least-squares scale and shift.

    The largest ``qtl`` fraction of absolute residuals ``|tensor1 - tensor2|``
    is excluded before fitting ``tensor1 ~ scale * tensor2 + shift`` on the
    remaining elements.

    Args:
        tensor1: Target values.
        tensor2: Values to align; same shape as ``tensor1``.
        qtl: Fraction of largest residuals excluded from the fit.

    Returns:
        ``tensor2 * scale + shift`` (same shape as ``tensor2``).

    Raises:
        RuntimeError: from ``torch.quantile`` when the inputs are empty.
    """
    error = (tensor1 - tensor2).abs()
    inliers = error < torch.quantile(error, 1 - qtl)
    target = tensor1.to(torch.float32)[inliers]
    source = tensor2.to(torch.float32)[inliers]

    # Normal equations of min ||[source, 1] @ [scale, shift] - target||^2 with a
    # small ridge term so the 2x2 system stays invertible.
    design = torch.stack([source, torch.ones_like(source)], dim=1)
    A = design.T @ design + 1e-4 * torch.eye(2, device=tensor1.device)
    det_A = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0]
    # Closed-form 2x2 inverse; torch.stack keeps the result on-device instead
    # of round-tripping each entry through the host as torch.tensor(...) did.
    adjugate = torch.stack(
        [torch.stack([A[1, 1], -A[0, 1]]), torch.stack([-A[1, 0], A[0, 0]])]
    )
    A_inv = (1.0 / det_A) * adjugate

    b = design.T @ target.unsqueeze(1)
    scale, shift = (A_inv @ b).squeeze().chunk(2, dim=0)
    return tensor2 * scale + shift
def si(tensor1, tensor2):
    """Rescale ``tensor2`` by the ratio of medians ``median(tensor1) / median(tensor2)``."""
    numerator = torch.median(tensor1)
    denominator = torch.median(tensor2)
    return tensor2 * numerator / denominator
def arel(tensor1, tensor2):
    """Mean absolute relative error of ``tensor2`` w.r.t. ``tensor1`` after median alignment.

    ``tensor2`` is first multiplied by ``median(tensor1) / median(tensor2)``.
    """
    aligned = tensor2 * torch.median(tensor1) / torch.median(tensor2)
    relative = torch.abs(tensor1 - aligned) / tensor1
    return relative.mean()
def d_auc(tensor1, tensor2):
    """Normalized area under the delta-accuracy curve for exponents in [0.01, 5].

    For each of 100 exponents ``e`` the accuracy is the fraction of elements
    with ``max(t1 / t2, t2 / t1) < 1.25 ** e`` (the same test as ``delta``).
    The curve is integrated with the trapezoid rule and divided by 5.

    Returns:
        0-d tensor.
    """
    exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device)
    # The symmetric ratio does not depend on the exponent: compute it once
    # instead of once per exponent.
    ratio = torch.maximum(tensor1 / tensor2, tensor2 / tensor1)
    deltas = torch.stack(
        [(ratio < 1.25**exponent).to(torch.float32).mean() for exponent in exponents]
    )
    # NOTE(review): the exponent range is 4.99 wide, so a perfect prediction
    # scores 0.998 rather than 1.0; kept to keep reported numbers comparable.
    return torch.trapz(deltas, exponents) / 5.0
def f1_score(tensor1, tensor2, thresholds):
    """Average F1 over a sweep of distance thresholds between two point clouds.

    Precision at a threshold is the fraction of ``dist1`` entries below it and
    recall the fraction of ``dist2`` entries below it, where ``dist1``/``dist2``
    are the two directional outputs of ``chamfer_cls``. F1 is 0 wherever
    precision and recall are both 0.

    Args:
        tensor1: Batched point cloud (shape[0] batch, shape[1] points).
        tensor2: Batched point cloud with the same layout.
        thresholds: Sized iterable of distance thresholds.

    Returns:
        0-d tensor: ``torch.trapz(f1_per_threshold) / len(thresholds)``.
    """
    x_lengths = torch.full(
        (tensor1.shape[0],), tensor1.shape[1], dtype=torch.long, device=tensor1.device
    )
    y_lengths = torch.full(
        (tensor2.shape[0],), tensor2.shape[1], dtype=torch.long, device=tensor2.device
    )
    dist1, dist2, _, _ = chamfer_cls(
        tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
    )
    # NOTE(review): thresholds are compared to the raw chamfer_cls output,
    # while chamfer_dist takes its sqrt — confirm the intended units.
    precisions = torch.stack(
        [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
    )
    recalls = torch.stack(
        [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
    )
    # 0/0 -> NaN when both precision and recall are 0; score that as F1 = 0.
    f1_thresholds = torch.nan_to_num(
        2 * precisions * recalls / (precisions + recalls), nan=0.0
    )
    # NOTE(review): trapz with unit spacing over n points spans n - 1 units, so
    # dividing by n caps the score at (n - 1) / n; kept for comparability.
    return torch.trapz(f1_thresholds) / len(thresholds)
def f1_score_si(tensor1, tensor2, thresholds):
    """Scale-invariant ``f1_score``.

    ``tensor2`` is multiplied by ``median(|tensor1|) / median(|tensor2|)``,
    with norms taken over the last dimension, before scoring.
    """
    gt_scale = torch.median(tensor1.norm(dim=-1))
    pred_scale = torch.median(tensor2.norm(dim=-1))
    rescaled = tensor2 * gt_scale / pred_scale
    return f1_score(tensor1, rescaled, thresholds)
# Depth metrics keyed by name; every entry takes (gt, pred) tensors of the
# valid pixels and returns a 0-d tensor.
DICT_METRICS = dict(
    d1=partial(delta, exponent=1.0),
    d2=partial(delta, exponent=2.0),
    d3=partial(delta, exponent=3.0),
    rmse=lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()),
    rmselog=lambda gt, pred: torch.sqrt(
        ((torch.log(gt) - torch.log(pred)) ** 2).mean()
    ),
    arel=lambda gt, pred: (torch.abs(gt - pred) / gt).mean(),
    sqrel=lambda gt, pred: (((gt - pred) ** 2) / gt).mean(),
    log10=lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(),
    silog=lambda gt, pred: 100
    * torch.std(torch.log(pred) - torch.log(gt)).mean(),
    medianlog=lambda gt, pred: 100
    * (torch.log(pred) - torch.log(gt)).median().abs(),
    d_auc=d_auc,
    tau=partial(tau, perc=0.03),
)
def _as_point_cloud(points):
    """Turn a (C, N) tensor into the (1, N, C) batched layout used below."""
    return points.unsqueeze(0).permute(0, 2, 1)


# 3D metrics keyed by name; entries take (gt, pred, thresholds), where gt/pred
# carry coordinates along dim 0 (norms below are taken over dim=0).
DICT_METRICS_3D = {
    "MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2),
    "arel_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2)
    / torch.norm(gt, dim=0, p=2),
    "tau_3d": lambda gt, pred, thresholds: (
        (torch.norm(pred, dim=0, p=2) / torch.norm(gt, dim=0, p=2)).log().abs().exp()
        < 1.25
    )
    .float()
    .mean(),
    "chamfer": lambda gt, pred, thresholds: chamfer_dist(
        _as_point_cloud(gt), _as_point_cloud(pred)
    ),
    "F1": lambda gt, pred, thresholds: f1_score(
        _as_point_cloud(gt), _as_point_cloud(pred), thresholds=thresholds
    ),
    "F1_si": lambda gt, pred, thresholds: f1_score_si(
        _as_point_cloud(gt), _as_point_cloud(pred), thresholds=thresholds
    ),
    "rays": lambda gt, pred, thresholds: rho(
        _as_point_cloud(gt), _as_point_cloud(pred)
    ),
}
def _endpoint_error(gt, pred):
    """Euclidean distance between flow vectors, reducing over dim 0."""
    return torch.sqrt(torch.square(gt - pred).sum(dim=0))


# Optical-flow metrics: the raw end-point error map and boolean maps of pixels
# whose error is below 1, 3 and 5.
DICT_METRICS_FLOW = {
    "epe": _endpoint_error,
    "epe1": lambda gt, pred: _endpoint_error(gt, pred) < 1,
    "epe3": lambda gt, pred: _endpoint_error(gt, pred) < 3,
    "epe5": lambda gt, pred: _endpoint_error(gt, pred) < 5,
}
# Per-pixel (unreduced) error maps:
#   a1      -> 1.0 where max(gt/pred, pred/gt) exceeds 1.25, else 0.0
#   abs_rel -> |gt - pred| / gt
DICT_METRICS_D = dict(
    a1=lambda gt, pred: (
        torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0
    ).to(torch.float32),
    abs_rel=lambda gt, pred: torch.abs(gt - pred) / gt,
)
def eval_depth(
    gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None
):
    """Compute every ``DICT_METRICS`` entry per sample on the valid pixels.

    For ``tau``, ``d1`` and ``arel``, extra ``<name>_ssi`` and ``<name>_si``
    entries are computed after aligning the prediction with ``ssi`` / ``si``.

    Args:
        gts: Ground-truth depths, iterated along the first dimension.
        preds: Predictions, indexed with the same masks as ``gts``. They are
            not resized here; presumably already at gt resolution.
        masks: Boolean validity masks, one per sample.
        max_depth: If given, pixels with ``gt > max_depth`` are also excluded.

    Returns:
        Dict mapping metric name to a 1-D tensor of per-sample values.
        Samples with no valid pixel are skipped, like in ``eval_3d``.
    """
    # Explicit lookup instead of eval() on the rescaler name.
    rescalers = {"ssi": ssi, "si": si}
    summary_metrics = defaultdict(list)
    for gt, pred, mask in zip(gts, preds, masks):
        if max_depth is not None:
            mask = mask & (gt <= max_depth)
        if not torch.any(mask):
            # Nothing to score; ssi would raise on an empty torch.quantile.
            continue
        gt_valid, pred_valid = gt[mask], pred[mask]
        for name, fn in DICT_METRICS.items():
            if name in ("tau", "d1", "arel"):
                for rescale_name, rescale_fn in rescalers.items():
                    summary_metrics[f"{name}_{rescale_name}"].append(
                        fn(gt_valid, rescale_fn(gt_valid, pred_valid))
                    )
            summary_metrics[name].append(fn(gt_valid, pred_valid).mean())
    return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
def eval_3d(
    gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None
):
    """Compute every ``DICT_METRICS_3D`` entry per sample on masked points.

    Inputs are first downsampled (nearest-exact) so that roughly at most
    ``MAX_PIXELS`` valid pixels remain, based on the first sample's mask.

    Args:
        gts: 4-D batch (required by F.interpolate); the channel axis holds the
            point coordinates.
        preds: Same layout as ``gts``.
        masks: Boolean validity masks with a single channel axis.
        thresholds: Passed to every metric; the ``F1``/``F1_si`` entries
            iterate over it, so it must be provided when they run.

    Returns:
        Dict mapping metric name to a 1-D tensor of per-sample values.
        Samples whose mask is empty after resizing are skipped.
    """
    summary_metrics = defaultdict(list)
    MAX_PIXELS = 75_000  # cap on valid pixels per sample after downsampling
    ratio = min(1.0, (MAX_PIXELS / masks[0].sum()) ** 0.5)
    h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio)
    gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact")
    preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact")
    masks = F.interpolate(
        masks.float(), size=(h_max, w_max), mode="nearest-exact"
    ).bool()
    for gt, pred, mask in zip(gts, preds, masks):
        if not torch.any(mask):
            continue
        # Drop only the channel axis: a bare squeeze() would also drop H or W
        # when either of them is 1 and break the indexing below.
        valid = mask.squeeze(0)
        gt_points, pred_points = gt[:, valid], pred[:, valid]
        for name, fn in DICT_METRICS_3D.items():
            summary_metrics[name].append(
                fn(gt_points, pred_points, thresholds).mean()
            )
    return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
def compute_aucs(gt, pred, mask, uncertainties, steps=50, metrics=("abs_rel",)):
    """Sparsification scores (AUSE, AURG) of an uncertainty map.

    Pixels are removed in ``steps`` quantile steps, most-uncertain first, and
    the mean per-pixel error of the remaining pixels forms the sparsification
    curve. The oracle curve removes the highest-error pixels first. Both curves
    end with an appended 0.

    Args:
        gt: Ground-truth depths.
        pred: Predicted depths, same shape as ``gt``.
        mask: Boolean mask of valid pixels.
        uncertainties: Per-pixel uncertainty, same shape as ``gt``.
        steps: Number of removal steps (the x axis has ``steps + 1`` points).
        metrics: Keys of ``DICT_METRICS_D`` (per-pixel error maps).

    Returns:
        Dict with ``AUSE_<metric>`` (area between sparsification and oracle
        curves) and ``AURG_<metric>`` (full-set mean error minus the area
        under the sparsification curve) for each metric.
    """
    results = {}
    x_axis = torch.linspace(0, 1, steps=steps + 1, device=gt.device)
    quantiles = torch.linspace(0, 1 - 1 / steps, steps=steps, device=gt.device)

    # Negate so the most uncertain pixels fall below the thresholds first.
    neg_uncert = -uncertainties[mask]
    gt = gt[mask]
    pred = pred[mask]

    thresholds = torch.quantile(neg_uncert, quantiles)
    subs = [neg_uncert >= t for t in thresholds]

    for metric in metrics:
        # Curves are subset means of the same per-pixel error map used for the
        # oracle ordering. (The previous version looked the curve function up in
        # DICT_METRICS, which has no "abs_rel"/"a1" key and raised KeyError.)
        error = DICT_METRICS_D[metric](gt, pred)
        neg_error = -error
        zero = error.new_zeros(())
        opt_thresholds = torch.quantile(neg_error, quantiles)
        opt_subs = [neg_error >= t for t in opt_thresholds]
        sparse_curve = torch.stack([error[sub].mean() for sub in subs] + [zero], dim=0)
        opt_curve = torch.stack(
            [error[sub].mean() for sub in opt_subs] + [zero], dim=0
        )
        rnd_curve = error.mean()
        results[f"AUSE_{metric}"] = torch.trapz(sparse_curve - opt_curve, x=x_axis)
        results[f"AURG_{metric}"] = rnd_curve - torch.trapz(sparse_curve, x=x_axis)
    return results
def eval_depth_uncertainties(
    gts: torch.Tensor,
    preds: torch.Tensor,
    uncertainties: torch.Tensor,
    masks: torch.Tensor,
    max_depth=None,
):
    """Per-sample ``DICT_METRICS`` plus sparsification AUCs from ``compute_aucs``.

    ``preds`` are bilinearly resized to the spatial size of ``gts`` first.

    Args:
        gts: Ground-truth depth batch (4-D, as required by F.interpolate).
        preds: Prediction batch, resized to ``gts.shape[-2:]``.
        uncertainties: Per-sample uncertainty maps. They are NOT resized and
            are indexed with the gt-resolution mask.
            NOTE(review): presumably already at gt resolution — confirm.
        masks: Boolean validity masks.
        max_depth: If given, pixels with ``gt >= max_depth`` are excluded.

    Returns:
        Dict mapping metric name to a 1-D tensor of per-sample values.
        Samples with no valid pixel are skipped.
    """
    summary_metrics = defaultdict(list)
    preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
    for gt, pred, mask, uncertainty in zip(gts, preds, masks, uncertainties):
        if max_depth is not None:
            mask = torch.logical_and(mask, gt < max_depth)
        if not torch.any(mask):
            # compute_aucs would raise on torch.quantile of an empty tensor.
            continue
        gt_valid, pred_valid = gt[mask], pred[mask]
        for name, fn in DICT_METRICS.items():
            summary_metrics[name].append(fn(gt_valid, pred_valid))
        for name, val in compute_aucs(gt, pred, mask, uncertainty).items():
            summary_metrics[name].append(val)
    return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
def lazy_eval_depth(
    gt_fns, pred_fns, min_depth=1e-2, max_depth=None, depth_scale=256.0
):
    """Evaluate depth maps stored as image files and average each metric.

    Each image is loaded with PIL, converted with ``pil_to_tensor`` and divided
    by ``depth_scale``. Valid pixels have ``gt > min_depth`` and, if given,
    ``gt < max_depth``.

    Args:
        gt_fns: Iterable of ground-truth image paths.
        pred_fns: Iterable of prediction image paths, paired with ``gt_fns``.
        min_depth: Lower (exclusive) bound on valid gt depth.
        max_depth: Optional upper (exclusive) bound on valid gt depth.
        depth_scale: Divisor converting stored pixel values to depth.

    Returns:
        Dict mapping each ``DICT_METRICS`` name to its mean over files, as a
        Python float. Files with no valid pixel are skipped.
    """

    def _load_depth(path):
        # The context manager closes the file; pil_to_tensor reads the pixels
        # before the handle is released.
        with Image.open(path) as image:
            return TF.pil_to_tensor(image).to(torch.float32) / depth_scale

    summary_metrics = defaultdict(list)
    for gt_fn, pred_fn in zip(gt_fns, pred_fns):
        gt = _load_depth(gt_fn)
        pred = _load_depth(pred_fn)
        mask = gt > min_depth
        if max_depth is not None:
            mask = torch.logical_and(mask, gt < max_depth)
        if not torch.any(mask):
            # Metrics of an empty selection are NaN and would poison the mean.
            continue
        gt_valid, pred_valid = gt[mask], pred[mask]
        for name, fn in DICT_METRICS.items():
            summary_metrics[name].append(fn(gt_valid, pred_valid))
    # torch.mean does not accept a Python list; stack the per-file values first.
    return {
        name: torch.stack(vals).mean().item() for name, vals in summary_metrics.items()
    }