Spaces:
Running
on
Zero
Running
on
Zero
from functools import partial
from typing import Callable, Optional, Union

from torch import nn, Tensor
from torch.nn import functional as F

from .utils import _init_weights
from .refine import ConvRefine, LightConvRefine, LighterConvRefine
class ConvUpsample(nn.Module):
    """Bilinearly upsample a feature map, then refine it with ``ConvRefine``.

    When ``scale_factor`` is greater than 1 the input is resized with
    ``F.interpolate`` (bilinear, ``align_corners=False``, no antialiasing).
    When it equals 1 the resize step is an ``nn.Identity``. The result is
    then passed through a ``ConvRefine`` block configured with
    ``in_channels``/``out_channels``. Finally, ``_init_weights`` is applied
    to every submodule.

    Args:
        in_channels: Channel count forwarded to ``ConvRefine`` as its input width.
        out_channels: Channel count forwarded to ``ConvRefine`` as its output width.
        scale_factor: Spatial upsampling factor; must be >= 1.
        norm_layer: Normalization layer class/factory forwarded to
            ``ConvRefine`` (default ``nn.BatchNorm2d``), or ``None``. How
            ``None`` is interpreted is decided by ``ConvRefine``.
        activation: Activation module forwarded to ``ConvRefine``.
        groups: Convolution ``groups`` value forwarded to ``ConvRefine``.

    Raises:
        ValueError: If ``scale_factor`` is not >= 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        # NOTE(review): this default is ONE ReLU instance created when the
        # class is defined; every ConvUpsample relying on the default hands
        # that same object to ConvRefine (harmless for a parameter-free ReLU).
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        # Explicit raise instead of ``assert`` so validation survives
        # ``python -O``; ``not x >= 1`` (rather than ``x < 1``) also rejects NaN.
        if not scale_factor >= 1:
            raise ValueError(
                f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
            )
        self.scale_factor = scale_factor
        if scale_factor > 1:
            self.upsample = partial(
                F.interpolate,
                scale_factor=scale_factor,
                mode="bilinear",
                align_corners=False,
                recompute_scale_factor=False,
                antialias=False,
            )
        else:
            # A factor of exactly 1 needs no resampling.
            self.upsample = nn.Identity()
        self.refine = ConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
            groups=groups,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` (when ``scale_factor > 1``) and refine the result.

        Bilinear ``F.interpolate`` requires a 4-D ``(N, C, H, W)`` input.
        """
        x = self.upsample(x)
        x = self.refine(x)
        return x
class LightConvUpsample(nn.Module):
    """Bilinearly upsample a feature map, then refine it with ``LightConvRefine``.

    When ``scale_factor`` is greater than 1 the input is resized with
    ``F.interpolate`` (bilinear, ``align_corners=False``, no antialiasing).
    When it equals 1 the resize step is an ``nn.Identity``. The result is
    then passed through a ``LightConvRefine`` block configured with
    ``in_channels``/``out_channels``. Finally, ``_init_weights`` is applied
    to every submodule.

    Args:
        in_channels: Channel count forwarded to ``LightConvRefine`` as its input width.
        out_channels: Channel count forwarded to ``LightConvRefine`` as its output width.
        scale_factor: Spatial upsampling factor; must be >= 1.
        norm_layer: Normalization layer class/factory forwarded to
            ``LightConvRefine`` (default ``nn.BatchNorm2d``), or ``None``.
            How ``None`` is interpreted is decided by ``LightConvRefine``.
        activation: Activation module forwarded to ``LightConvRefine``.

    Raises:
        ValueError: If ``scale_factor`` is not >= 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        # NOTE(review): this default is ONE ReLU instance created when the
        # class is defined; every instance relying on the default hands that
        # same object to LightConvRefine (harmless for a parameter-free ReLU).
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # Explicit raise instead of ``assert`` so validation survives
        # ``python -O``; ``not x >= 1`` (rather than ``x < 1``) also rejects NaN.
        if not scale_factor >= 1:
            raise ValueError(
                f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
            )
        self.scale_factor = scale_factor
        if scale_factor > 1:
            self.upsample = partial(
                F.interpolate,
                scale_factor=scale_factor,
                mode="bilinear",
                align_corners=False,
                recompute_scale_factor=False,
                antialias=False,
            )
        else:
            # A factor of exactly 1 needs no resampling.
            self.upsample = nn.Identity()
        self.refine = LightConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` (when ``scale_factor > 1``) and refine the result.

        Bilinear ``F.interpolate`` requires a 4-D ``(N, C, H, W)`` input.
        """
        x = self.upsample(x)
        x = self.refine(x)
        return x
class LighterConvUpsample(nn.Module):
    """Bilinearly upsample a feature map, then refine it with ``LighterConvRefine``.

    When ``scale_factor`` is greater than 1 the input is resized with
    ``F.interpolate`` (bilinear, ``align_corners=False``, no antialiasing).
    When it equals 1 the resize step is an ``nn.Identity``. The result is
    then passed through a ``LighterConvRefine`` block configured with
    ``in_channels``/``out_channels``. Finally, ``_init_weights`` is applied
    to every submodule.

    Args:
        in_channels: Channel count forwarded to ``LighterConvRefine`` as its input width.
        out_channels: Channel count forwarded to ``LighterConvRefine`` as its output width.
        scale_factor: Spatial upsampling factor; must be >= 1.
        norm_layer: Normalization layer class/factory forwarded to
            ``LighterConvRefine`` (default ``nn.BatchNorm2d``), or ``None``.
            How ``None`` is interpreted is decided by ``LighterConvRefine``.
        activation: Activation module forwarded to ``LighterConvRefine``.

    Raises:
        ValueError: If ``scale_factor`` is not >= 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        # NOTE(review): this default is ONE ReLU instance created when the
        # class is defined; every instance relying on the default hands that
        # same object to LighterConvRefine (harmless for a parameter-free ReLU).
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # Explicit raise instead of ``assert`` so validation survives
        # ``python -O``; ``not x >= 1`` (rather than ``x < 1``) also rejects NaN.
        if not scale_factor >= 1:
            raise ValueError(
                f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
            )
        self.scale_factor = scale_factor
        if scale_factor > 1:
            self.upsample = partial(
                F.interpolate,
                scale_factor=scale_factor,
                mode="bilinear",
                align_corners=False,
                recompute_scale_factor=False,
                antialias=False,
            )
        else:
            # A factor of exactly 1 needs no resampling.
            self.upsample = nn.Identity()
        self.refine = LighterConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` (when ``scale_factor > 1``) and refine the result.

        Bilinear ``F.interpolate`` requires a 4-D ``(N, C, H, W)`` input.
        """
        x = self.upsample(x)
        x = self.refine(x)
        return x