Spaces:
Running
on
Zero
Running
on
Zero
| from timm import create_model | |
| from torch import nn, Tensor | |
| from typing import Optional | |
| from functools import partial | |
| from ..utils import _get_activation, _get_norm_layer, ConvUpsample, ConvDownsample | |
| from ..utils import LightConvUpsample, LightConvDownsample, LighterConvUpsample, LighterConvDownsample | |
| from ..utils import ConvRefine, LightConvRefine, LighterConvRefine | |
# Backbone names grouped by the conv-block family TIMMModel pairs them with.
# "regular" and "heavy" backbones use the grouped Conv* blocks.
regular_models = [
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "convnext_nano",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "mobilenetv4_conv_large",
]

heavy_models = [
    "convnext_large",
    "convnext_xlarge",
    "convnext_xxlarge",
]

# Backbones paired with the LightConv* blocks.
light_models = [
    "mobilenetv1_100",
    "mobilenetv1_125",
    "mobilenetv2_100",
    "mobilenetv2_140",
    "mobilenetv3_large_100",
    "mobilenetv4_conv_medium",
]

# Backbones paired with the LighterConv* blocks.
lighter_models = [
    "mobilenetv2_050",
    "mobilenetv3_small_050",
    "mobilenetv3_small_075",
    "mobilenetv3_small_100",
    "mobilenetv4_conv_small_050",
    "mobilenetv4_conv_small",
]

# Every backbone name accepted by TIMMModel, in family order.
supported_models = [*regular_models, *heavy_models, *light_models, *lighter_models]
# Channel width between the two resampling blocks of a two-step refiner,
# keyed by backbone name (used as out_channels of the first block and
# in_channels of the second).
refiner_in_channels = {
    # ResNet
    "resnet18": 512, "resnet34": 512,
    "resnet50": 2048, "resnet101": 2048, "resnet152": 2048,
    # ConvNeXt
    "convnext_nano": 640,
    "convnext_tiny": 768, "convnext_small": 768,
    "convnext_base": 1024,
    "convnext_large": 1536,
    "convnext_xlarge": 2048,
    "convnext_xxlarge": 3072,
    # MobileNet V1
    "mobilenetv1_100": 1024, "mobilenetv1_125": 1280,
    # MobileNet V2
    "mobilenetv2_050": 160, "mobilenetv2_100": 320, "mobilenetv2_140": 448,
    # MobileNet V3
    "mobilenetv3_small_050": 288, "mobilenetv3_small_075": 432, "mobilenetv3_small_100": 576,
    "mobilenetv3_large_100": 960,
    # MobileNet V4
    "mobilenetv4_conv_small_050": 480,
    "mobilenetv4_conv_small": 960, "mobilenetv4_conv_medium": 960, "mobilenetv4_conv_large": 960,
}
# Channel width produced by the refiner whenever it resamples (one or two
# blocks), keyed by backbone name. TIMMModel stores it as refiner_channels.
refiner_out_channels = {
    # ResNet
    "resnet18": 512, "resnet34": 512,
    "resnet50": 2048, "resnet101": 2048, "resnet152": 2048,
    # ConvNeXt
    "convnext_nano": 640,
    "convnext_tiny": 768, "convnext_small": 768,
    "convnext_base": 1024,
    "convnext_large": 1536,
    "convnext_xlarge": 2048,
    "convnext_xxlarge": 3072,
    # MobileNet V1
    "mobilenetv1_100": 512, "mobilenetv1_125": 640,
    # MobileNet V2
    "mobilenetv2_050": 160, "mobilenetv2_100": 320, "mobilenetv2_140": 448,
    # MobileNet V3
    "mobilenetv3_small_050": 288, "mobilenetv3_small_075": 432, "mobilenetv3_small_100": 576,
    "mobilenetv3_large_100": 480,
    # MobileNet V4
    "mobilenetv4_conv_small_050": 480,
    "mobilenetv4_conv_small": 960, "mobilenetv4_conv_medium": 960, "mobilenetv4_conv_large": 960,
}
# `groups` argument bound onto ConvUpsample / ConvDownsample / ConvRefine,
# keyed by backbone name. TIMMModel only reads this table for backbones that
# are in neither light_models nor lighter_models; the None entries belong to
# backbones in those two lists and are never passed on.
groups = {
    # ResNet
    "resnet18": 1, "resnet34": 1,
    # One group per 512 channels of the backbone's refiner input width.
    **{
        name: refiner_in_channels[name] // 512
        for name in ("resnet50", "resnet101", "resnet152")
    },
    # ConvNeXt
    "convnext_nano": 8, "convnext_tiny": 8, "convnext_small": 8, "convnext_base": 8,
    **{
        name: refiner_in_channels[name] // 512
        for name in ("convnext_large", "convnext_xlarge", "convnext_xxlarge")
    },
    # MobileNet V1
    "mobilenetv1_100": None, "mobilenetv1_125": None,
    # MobileNet V2
    "mobilenetv2_050": None, "mobilenetv2_100": None, "mobilenetv2_140": None,
    # MobileNet V3
    "mobilenetv3_small_050": None, "mobilenetv3_small_075": None, "mobilenetv3_small_100": None,
    "mobilenetv3_large_100": None,
    # MobileNet V4
    "mobilenetv4_conv_small_050": None, "mobilenetv4_conv_small": None,
    "mobilenetv4_conv_medium": None,
    "mobilenetv4_conv_large": 1,
}
class TIMMModel(nn.Module):
    """timm backbone followed by a resampling "refiner" and a channel-reducing "decoder".

    The encoder is created through ``timm.create_model`` with
    ``features_only=True`` and ``out_indices=[-1]``, so only its last feature
    map is used. The refiner is chosen by comparing ``block_size`` with the
    reduction reported by the encoder:

    * equal: ``nn.Identity``;
    * 2x larger / 4x larger: one / two downsample blocks;
    * 2x smaller / 4x smaller: one / two upsample blocks.

    The decoder then halves the channel count zero, one, two or three times
    depending on whether the refiner width exceeds 256, 512 and 1024.

    Args:
        model_name: Backbone name; must be listed in ``supported_models``.
        block_size: ``None`` (use the encoder's own reduction) or one of
            8, 16, 32.
        norm: ``"bn"`` selects ``nn.BatchNorm2d``, ``"ln"`` selects
            ``nn.LayerNorm``; any other value defers to
            ``_get_norm_layer(self.encoder)``.
        act: ``"relu"`` selects ``nn.ReLU(inplace=True)``, ``"gelu"`` selects
            ``nn.GELU()``; any other value defers to
            ``_get_activation(self.encoder)``.

    Raises:
        AssertionError: If ``model_name`` or ``block_size`` is not supported,
            or ``block_size`` is not 1x, 2x, 4x, 1/2 or 1/4 of the encoder
            reduction.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in supported_models, f"Backbone {model_name} not supported. Supported models are {supported_models}"
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."

        self.model_name = model_name

        # ---- Encoder -------------------------------------------------------
        # pretrained=False: no pretrained weights are requested from timm here
        # (the alternative would be pretrained=True with the same arguments).
        self.encoder = create_model(model_name, pretrained=False, features_only=True, out_indices=[-1])
        self.encoder_channels = self.encoder.feature_info.channels()[-1]
        self.encoder_reduction = self.encoder.feature_info.reduction()[-1]
        self.block_size = self.encoder_reduction if block_size is None else block_size

        # ---- Building blocks matching the backbone family ------------------
        if model_name in lighter_models:
            upsample_block, downsample_block, decoder_block = (
                LighterConvUpsample, LighterConvDownsample, LighterConvRefine,
            )
        elif model_name in light_models:
            upsample_block, downsample_block, decoder_block = (
                LightConvUpsample, LightConvDownsample, LightConvRefine,
            )
        else:
            num_groups = groups[model_name]
            upsample_block = partial(ConvUpsample, groups=num_groups)
            downsample_block = partial(ConvDownsample, groups=num_groups)
            decoder_block = partial(ConvRefine, groups=num_groups)

        # NOTE(review): nn.LayerNorm is handed to the blocks as-is; how the
        # utils blocks apply it to their feature maps is not visible here.
        norm_layer = (
            nn.BatchNorm2d if norm == "bn"
            else nn.LayerNorm if norm == "ln"
            else _get_norm_layer(self.encoder)
        )
        activation = (
            nn.ReLU(inplace=True) if act == "relu"
            else nn.GELU() if act == "gelu"
            else _get_activation(self.encoder)
        )
        # Keyword arguments shared by every refiner / decoder block.
        shared = {"norm_layer": norm_layer, "activation": activation}

        # ---- Refiner: bring the feature map to the requested block size ----
        reduction = self.encoder_reduction
        if self.block_size > reduction:
            resample_block = downsample_block
            one_step, two_step = reduction * 2, reduction * 4
            needs_two = self.block_size > one_step
        elif self.block_size < reduction:
            resample_block = upsample_block
            one_step, two_step = reduction // 2, reduction // 4
            needs_two = self.block_size < one_step
        else:
            resample_block = None

        if resample_block is None:
            self.refiner = nn.Identity()
            self.refiner_channels = self.encoder_channels
        else:
            expected = two_step if needs_two else one_step
            assert self.block_size == expected, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {reduction}, {one_step}, and {two_step}."
            out_width = refiner_out_channels[self.model_name]
            if needs_two:
                mid_width = refiner_in_channels[self.model_name]
                self.refiner = nn.Sequential(
                    resample_block(in_channels=self.encoder_channels, out_channels=mid_width, **shared),
                    resample_block(in_channels=mid_width, out_channels=out_width, **shared),
                )
            else:
                # A single block is stored bare (not wrapped in nn.Sequential).
                self.refiner = resample_block(
                    in_channels=self.encoder_channels, out_channels=out_width, **shared
                )
            self.refiner_channels = out_width
        self.refiner_reduction = self.block_size

        # ---- Decoder: halve the width once per threshold exceeded ----------
        width = self.refiner_channels
        num_halvings = sum(width > limit for limit in (256, 512, 1024))
        stages = [
            decoder_block(
                in_channels=width // 2 ** idx,
                out_channels=width // 2 ** (idx + 1),
                **shared,
            )
            for idx in range(num_halvings)
        ]
        if not stages:
            self.decoder = nn.Identity()
        elif len(stages) == 1:
            # A single stage is stored bare (not wrapped in nn.Sequential).
            self.decoder = stages[0]
        else:
            self.decoder = nn.Sequential(*stages)
        self.decoder_channels = width // 2 ** num_halvings
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the backbone and return the first (only requested) feature map."""
        features = self.encoder(x)
        return features[0]

    def refine(self, x: Tensor) -> Tensor:
        """Apply the refiner to encoder features."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Apply the decoder to refined features."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Chain ``encode``, ``refine`` and ``decode``."""
        return self.decode(self.refine(self.encode(x)))
def _timm_model(
    model_name: str,
    block_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
) -> TIMMModel:
    """Build a ``TIMMModel``, forwarding every argument unchanged."""
    model = TIMMModel(model_name=model_name, block_size=block_size, norm=norm, act=act)
    return model