Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from torch import nn, Tensor | |
| import torch.nn.functional as F | |
| from typing import Tuple, Any, Optional, Union | |
| from types import FunctionType | |
| from itertools import repeat | |
| from collections.abc import Iterable | |
def _log_api_usage_once(obj: Any) -> None:
    """
    Report a single API-usage event for ``obj`` to PyTorch's usage logger.

    The event key has the form ``"<module>.<name>"``:

    * ``<module>`` is ``obj.__module__``; when it does not already start with
      ``"torchvision"`` it is prefixed with ``"torchvision.internal."``.
    * ``<name>`` is the function's own name when ``obj`` is a plain Python
      function, otherwise the name of ``obj``'s class.

    The key is handed to ``torch._C._log_api_usage_once``. According to the
    upstream notes this hook is a no-op unless a logger has been registered
    and fires at most once per key within a process; see
    https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging
    and https://github.com/pytorch/vision/issues/5052 for the logging policy.

    Args:
        obj (class instance or method): an object to extract info from.
    """
    namespace = obj.__module__
    if not namespace.startswith("torchvision"):
        namespace = f"torchvision.internal.{namespace}"

    # Plain functions are identified by their own name; everything else by its class.
    label = obj.__name__ if isinstance(obj, FunctionType) else obj.__class__.__name__

    torch._C._log_api_usage_once(f"{namespace}.{label}")
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Normalise ``x`` into a tuple.

    An iterable ``x`` is converted to a tuple as-is (its length is not
    checked against ``n``, and a string is split into its characters).
    Any other value is repeated ``n`` times.

    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple when ``x`` is not iterable
    """
    # Pick the source of items first, then materialise it once.
    items = x if isinstance(x, Iterable) else repeat(x, n)
    return tuple(items)
def _init_weights(model: nn.Module) -> None:
    """
    Re-initialise, in place, the parameters of every supported layer in ``model``.

    Walks ``model.modules()`` and applies:

    * ``nn.Conv2d``: Kaiming-normal weights (``mode="fan_out"``, ReLU gain).
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm`` / ``nn.LayerNorm``: weight set to 1.
    * ``nn.Linear``: weights drawn from a normal distribution with ``std=0.01``.

    For each of these layers the bias, when present, is set to 0. Modules of
    any other type are left untouched.

    Args:
        model (nn.Module): module whose sub-modules are initialised in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            # Norm layers built with affine=False / elementwise_affine=False
            # carry weight=None; skip them rather than crash in constant_().
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
        else:
            continue

        # Shared by all handled layer types: zero the bias if the layer has one.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.)
def interpolate_pos_embed(pos_embed: Tensor, size: Optional[Union[int, Tuple[int, int]]] = None, scale_factor: Optional[float] = None) -> Tensor:
    """
    Resize a 3D positional embedding with antialiased bicubic interpolation.

    The input is given a temporary leading batch dimension, passed through
    ``F.interpolate`` (``mode="bicubic"``, ``align_corners=False``,
    ``antialias=True``), and the batch dimension is removed again.

    Args:
        pos_embed (Tensor): 3D tensor; the assertion message describes it as (C, H, W).
        size (int or Tuple[int, int], optional): target spatial size, forwarded to ``F.interpolate``.
        scale_factor (float, optional): spatial scale multiplier, forwarded to ``F.interpolate``.

    Returns:
        Tensor: the resized 3D embedding.
    """
    assert pos_embed.ndim == 3, f"Positional embedding should be 3D tensor (C, H, W), but got {pos_embed.shape}."

    # F.interpolate expects a batched input for 2D resampling, so add a batch axis.
    batched = pos_embed.unsqueeze(0)
    resized = F.interpolate(
        batched,
        size=size,
        scale_factor=scale_factor,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )
    return resized.squeeze(0)