Spaces:
Sleeping
Sleeping
File size: 1,065 Bytes
199c8cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
import torch.nn as nn
class Preprocessor(nn.Module):
def __init__(self, input_data_size):
super(Preprocessor, self).__init__()
self.input_data_size = input_data_size
C, H, W = input_data_size
self.data_dim = C * H * W
self.register_buffer("mean", torch.empty(self.data_dim))
self.register_buffer("std", torch.empty(self.data_dim))
def prepare(self, mean, var):
self.mean = mean.to(self.mean.device)
self.std = torch.sqrt(var.to(self.std.device))
def forward(self, x):
# normalize: (B, C, H, W) -> (B, C, H, W)
B, C, H, W = x.size()
x = x.view(B, self.data_dim) # (B, C*H*W)
x = (x - self.mean) / self.std
x = x.view(B, C, H, W).detach()
return x
def inverse(self, x):
# un-normalize: (B, C, H, W) -> (B, C, H, W)
B = x.size(0)
C, H, W = self.input_data_size
x = x.view(B, self.data_dim) # (B, C*H*W)
x = x * self.std + self.mean
x = x.view(B, C, H, W)
return x |