Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Residual block as defined in: | |
| # He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning | |
| # for image recognition." In Proceedings of the IEEE conference on computer vision | |
| # and pattern recognition, pp. 770-778. 2016. | |
| # | |
| # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net) | |
| # | |
| # @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
| # Institute for Artificial Intelligence in Medicine, | |
| # University Medicine Essen | |
| import torch | |
| import torch.nn as nn | |
| from collections import OrderedDict | |
| from models.utils.tf_utils import TFSamepaddingLayer | |
class ResidualBlock(nn.Module):
    """Stack of pre-activation residual units followed by a BatchNorm + ReLU.

    Residual block as defined in:
        He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
        for image recognition." In Proceedings of the IEEE conference on computer vision
        and pattern recognition, pp. 770-778. 2016.

    Every unit is a sequence of three bias-free convolutions (``conv1``,
    ``conv2``, ``conv3``) with BatchNorm + ReLU after the first two. All units
    except the first are additionally preceded by a ``preact/bn`` +
    ``preact/relu`` pair. The output of each unit is added to a running
    shortcut, and the whole stack is concluded by ``blk_bna`` (BatchNorm + ReLU).

    Args:
        in_ch: Number of channels of the incoming feature map.
        unit_ksize: Kernel sizes of the three convolutions of a unit. Entries
            0, 1 and 2 are read, so at least three values are required as soon
            as ``unit_count`` is greater than zero.
        unit_ch: Output channels of the three convolutions of a unit. Must have
            the same length as ``unit_ksize``; the last entry is the block's
            output channel count.
        unit_count: Number of residual units in the block.
        stride: Stride of ``conv2`` in the first unit and of the shortcut
            projection. All other convolutions use stride 1. Defaults to 1.
    """

    def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1):
        super().__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # ! For inference only so init values for batchnorm may not match tensorflow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for idx in range(unit_count):
            # Only the first unit of the block may change the spatial resolution.
            unit_stride = stride if idx == 0 else 1

            unit_layer = []
            # * has bna to conclude each previous block so
            # * must not put preact for the first unit of this block
            if idx != 0:
                unit_layer += [
                    ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("preact/relu", nn.ReLU(inplace=True)),
                ]
            unit_layer += [
                (
                    "conv1",
                    nn.Conv2d(
                        unit_in_ch,
                        unit_ch[0],
                        unit_ksize[0],
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
                ("conv1/relu", nn.ReLU(inplace=True)),
                # conv2 itself is unpadded; padding is delegated to the
                # project-local TFSamepaddingLayer (presumably TensorFlow-style
                # "same" padding -- confirm against models.utils.tf_utils).
                (
                    "conv2/pad",
                    TFSamepaddingLayer(ksize=unit_ksize[1], stride=unit_stride),
                ),
                (
                    "conv2",
                    nn.Conv2d(
                        unit_ch[0],
                        unit_ch[1],
                        unit_ksize[1],
                        stride=unit_stride,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)),
                ("conv2/relu", nn.ReLU(inplace=True)),
                (
                    "conv3",
                    nn.Conv2d(
                        unit_ch[1],
                        unit_ch[2],
                        unit_ksize[2],
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                ),
            ]
            self.units.append(nn.Sequential(OrderedDict(unit_layer)))
            # Every following unit consumes the output of the previous one.
            unit_in_ch = unit_ch[-1]

        # A 1x1 projection is needed whenever the identity shortcut would not
        # match the unit output (different channel count or strided units).
        if in_ch != unit_ch[-1] or stride != 1:
            self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False)
        else:
            self.shortcut = None

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )

    def out_ch(self):
        """Return the number of output channels (last entry of ``unit_ch``)."""
        return self.unit_ch[-1]

    def init_weights(self):
        """Kaiming (HE) initialization for convolutional layers and constant initialization for normalization and linear layers"""
        for m in self.modules():
            classname = m.__class__.__name__.lower()
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            # Normalization / linear layers are matched by class name.
            if "norm" in classname:
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if "linear" in classname:
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, prev_feat, freeze=False):
        """Run the residual units and the concluding BatchNorm + ReLU.

        Args:
            prev_feat: Input feature tensor.
            freeze: Only evaluated while the module is in training mode. The
                units are then executed under
                ``torch.set_grad_enabled(not freeze)``, i.e. without gradient
                recording when ``freeze`` is True. The shortcut projection and
                ``blk_bna`` run outside of that context. In eval mode the
                argument is ignored.

        Returns:
            The output of ``blk_bna`` applied to the accumulated residual sum.
        """
        if self.shortcut is None:
            shortcut = prev_feat
        else:
            shortcut = self.shortcut(prev_feat)

        for unit in self.units:
            if self.training:
                with torch.set_grad_enabled(not freeze):
                    new_feat = unit(prev_feat)
            else:
                new_feat = unit(prev_feat)
            prev_feat = new_feat + shortcut
            # The sum becomes the identity path for the next unit.
            shortcut = prev_feat

        feat = self.blk_bna(prev_feat)
        return feat