Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class TFSamepaddingLayer(nn.Module):
    """Zero-pad a feature map so a following unpadded layer mimics TF ``same`` padding.

    Put this layer immediately before a convolution/pooling layer that uses
    ``padding=0`` with the same ``ksize`` and ``stride``. The kernel is
    assumed square (height == width), but the input's height and width are
    padded independently, so non-square inputs are handled correctly.

    As in TensorFlow, when the total padding along a dimension is odd, the
    extra element goes on the end side (bottom / right).

    Args:
        ksize: Kernel size (used for both height and width) of the following layer.
        stride: Stride (used for both height and width) of the following layer.
    """

    def __init__(self, ksize, stride):
        super(TFSamepaddingLayer, self).__init__()
        self.ksize = ksize
        self.stride = stride

    @staticmethod
    def _pad_amounts(size, ksize, stride):
        """Return ``(start, end)`` zero-padding for one spatial dimension of length ``size``."""
        if size % stride == 0:
            total = max(ksize - stride, 0)
        else:
            total = max(ksize - (size % stride), 0)
        start = total // 2
        # TF puts any odd leftover element on the end side.
        return start, total - start

    def forward(self, x):
        """Zero-pad the last two dimensions of ``x``.

        Args:
            x: Tensor whose last two dims are height and width
                (presumably NCHW — confirm against callers).

        Returns:
            ``x`` zero-padded along its last two dimensions, with the padding
            for each dimension computed from that dimension's own size.
        """
        top, bottom = self._pad_amounts(x.shape[-2], self.ksize, self.stride)
        left, right = self._pad_amounts(x.shape[-1], self.ksize, self.stride)
        # F.pad takes pairs starting from the LAST dim: (W_start, W_end, H_start, H_end).
        return F.pad(x, (left, right, top, bottom), "constant", 0)