Spaces:
Runtime error
Runtime error
| # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # -------------------------------------------------------------------------- | |
| # If you find this code useful, we kindly ask you to cite our paper in your work. | |
| # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation | |
| # More information about the method can be found at https://marigoldmonodepth.github.io | |
| # -------------------------------------------------------------------------- | |
| from functools import partial | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from .image_util import get_tv_resample_method, resize_max_res | |
| def inter_distances(tensors: torch.Tensor): | |
| """ | |
| To calculate the distance between each two depth maps. | |
| """ | |
| distances = [] | |
| for i, j in torch.combinations(torch.arange(tensors.shape[0])): | |
| arr1 = tensors[i : i + 1] | |
| arr2 = tensors[j : j + 1] | |
| distances.append(arr1 - arr2) | |
| dist = torch.concatenate(distances, dim=0) | |
| return dist | |
| def ensemble_depth( | |
| depth: torch.Tensor, | |
| scale_invariant: bool = True, | |
| shift_invariant: bool = True, | |
| output_uncertainty: bool = False, | |
| reduction: str = "median", | |
| regularizer_strength: float = 0.02, | |
| max_iter: int = 2, | |
| tol: float = 1e-3, | |
| max_res: int = 1024, | |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: | |
| """ | |
| Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the | |
| number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for | |
| depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The | |
| alignment happens when the predictions have one or more degrees of freedom, that is when they are either | |
| affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only | |
| `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) | |
| alignment is skipped and only ensembling is performed. | |
| Args: | |
| depth (`torch.Tensor`): | |
| Input ensemble depth maps. | |
| scale_invariant (`bool`, *optional*, defaults to `True`): | |
| Whether to treat predictions as scale-invariant. | |
| shift_invariant (`bool`, *optional*, defaults to `True`): | |
| Whether to treat predictions as shift-invariant. | |
| output_uncertainty (`bool`, *optional*, defaults to `False`): | |
| Whether to output uncertainty map. | |
| reduction (`str`, *optional*, defaults to `"median"`): | |
| Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and | |
| `"median"`. | |
| regularizer_strength (`float`, *optional*, defaults to `0.02`): | |
| Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. | |
| max_iter (`int`, *optional*, defaults to `2`): | |
| Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` | |
| argument. | |
| tol (`float`, *optional*, defaults to `1e-3`): | |
| Alignment solver tolerance. The solver stops when the tolerance is reached. | |
| max_res (`int`, *optional*, defaults to `1024`): | |
| Resolution at which the alignment is performed; `None` matches the `processing_resolution`. | |
| Returns: | |
| A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: | |
| `(1, 1, H, W)`. | |
| """ | |
| if depth.dim() != 4 or depth.shape[1] != 1: | |
| raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") | |
| if reduction not in ("mean", "median"): | |
| raise ValueError(f"Unrecognized reduction method: {reduction}.") | |
| if not scale_invariant and shift_invariant: | |
| raise ValueError("Pure shift-invariant ensembling is not supported.") | |
| def init_param(depth: torch.Tensor): | |
| init_min = depth.reshape(ensemble_size, -1).min(dim=1).values | |
| init_max = depth.reshape(ensemble_size, -1).max(dim=1).values | |
| if scale_invariant and shift_invariant: | |
| init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) | |
| init_t = -init_s * init_min | |
| param = torch.cat((init_s, init_t)).cpu().numpy() | |
| elif scale_invariant: | |
| init_s = 1.0 / init_max.clamp(min=1e-6) | |
| param = init_s.cpu().numpy() | |
| else: | |
| raise ValueError("Unrecognized alignment.") | |
| return param | |
| def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: | |
| if scale_invariant and shift_invariant: | |
| s, t = np.split(param, 2) | |
| s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) | |
| t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) | |
| out = depth * s + t | |
| elif scale_invariant: | |
| s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) | |
| out = depth * s | |
| else: | |
| raise ValueError("Unrecognized alignment.") | |
| return out | |
| def ensemble( | |
| depth_aligned: torch.Tensor, return_uncertainty: bool = False | |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: | |
| uncertainty = None | |
| if reduction == "mean": | |
| prediction = torch.mean(depth_aligned, dim=0, keepdim=True) | |
| if return_uncertainty: | |
| uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) | |
| elif reduction == "median": | |
| prediction = torch.median(depth_aligned, dim=0, keepdim=True).values | |
| if return_uncertainty: | |
| uncertainty = torch.median( | |
| torch.abs(depth_aligned - prediction), dim=0, keepdim=True | |
| ).values | |
| else: | |
| raise ValueError(f"Unrecognized reduction method: {reduction}.") | |
| return prediction, uncertainty | |
| def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: | |
| cost = 0.0 | |
| depth_aligned = align(depth, param) | |
| for i, j in torch.combinations(torch.arange(ensemble_size)): | |
| diff = depth_aligned[i] - depth_aligned[j] | |
| cost += (diff**2).mean().sqrt().item() | |
| if regularizer_strength > 0: | |
| prediction, _ = ensemble(depth_aligned, return_uncertainty=False) | |
| err_near = (0.0 - prediction.min()).abs().item() | |
| err_far = (1.0 - prediction.max()).abs().item() | |
| cost += (err_near + err_far) * regularizer_strength | |
| return cost | |
| def compute_param(depth: torch.Tensor): | |
| import scipy | |
| depth_to_align = depth.to(torch.float32) | |
| if max_res is not None and max(depth_to_align.shape[2:]) > max_res: | |
| try: | |
| depth_to_align = resize_max_res( | |
| depth_to_align, max_res, get_tv_resample_method("nearest-exact") | |
| ) | |
| except: | |
| depth_to_align = resize_max_res( | |
| depth_to_align, max_res, get_tv_resample_method("bilinear") | |
| ) | |
| param = init_param(depth_to_align) | |
| res = scipy.optimize.minimize( | |
| partial(cost_fn, depth=depth_to_align), | |
| param, | |
| method="BFGS", | |
| tol=tol, | |
| options={"maxiter": max_iter, "disp": False}, | |
| ) | |
| return res.x | |
| requires_aligning = scale_invariant or shift_invariant | |
| ensemble_size = depth.shape[0] | |
| if requires_aligning: | |
| param = compute_param(depth) | |
| depth = align(depth, param) | |
| depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) | |
| depth_max = depth.max() | |
| if scale_invariant and shift_invariant: | |
| depth_min = depth.min() | |
| elif scale_invariant: | |
| depth_min = 0 | |
| else: | |
| raise ValueError("Unrecognized alignment.") | |
| depth_range = (depth_max - depth_min).clamp(min=1e-6) | |
| depth = (depth - depth_min) / depth_range | |
| if output_uncertainty: | |
| uncertainty /= depth_range | |
| return depth, uncertainty # [1,1,H,W], [1,1,H,W] | |