Spaces: Sleeping
| import torch | |
| import torch.nn as nn | |
class Preprocessor(nn.Module):
    """Per-element standardization of image-like tensors.

    Stores one ``mean`` and one ``std`` value per element of a flattened
    ``(C, H, W)`` sample as registered buffers. They therefore follow the
    module across devices and are included in its ``state_dict``.

    Args:
        input_data_size: ``(C, H, W)`` shape of a single sample.
    """

    def __init__(self, input_data_size):
        super().__init__()
        self.input_data_size = input_data_size
        C, H, W = input_data_size
        self.data_dim = C * H * W
        # Start as the identity transform (mean 0, std 1) rather than
        # uninitialized memory, so forward()/inverse() are well-defined
        # even before prepare() is called.
        self.register_buffer("mean", torch.zeros(self.data_dim))
        self.register_buffer("std", torch.ones(self.data_dim))

    def prepare(self, mean, var):
        """Set the normalization statistics.

        Args:
            mean: per-element mean tensor with ``C*H*W`` elements in any
                shape, e.g. ``(C, H, W)`` or ``(C*H*W,)``.
            var: per-element variance tensor with ``C*H*W`` elements. The
                stored ``std`` is ``sqrt(var)``.

        Raises:
            ValueError: if ``mean`` or ``var`` does not have ``C*H*W``
                elements. The check runs before any buffer is modified.
        """
        for label, stat in (("mean", mean), ("var", var)):
            if stat.numel() != self.data_dim:
                raise ValueError(
                    f"{label} must have {self.data_dim} elements "
                    f"(C*H*W), got shape {tuple(stat.shape)}"
                )
        # Flatten so the stats broadcast against the (B, C*H*W) view used in
        # forward()/inverse(), whatever shape the caller supplied them in.
        self.mean = mean.reshape(self.data_dim).to(self.mean.device)
        self.std = torch.sqrt(var.reshape(self.data_dim).to(self.std.device))

    def forward(self, x):
        """Normalize ``x``: ``(B, C, H, W) -> (B, C, H, W)``.

        Computes ``(x - mean) / std`` element-wise. The returned tensor is
        detached from the autograd graph.
        """
        B, C, H, W = x.size()
        # reshape (not view) also accepts non-contiguous inputs, e.g. a
        # transposed or permuted batch.
        x = x.reshape(B, self.data_dim)  # (B, C*H*W)
        x = (x - self.mean) / self.std
        return x.view(B, C, H, W).detach()

    def inverse(self, x):
        """Undo normalization: ``(B, C, H, W) -> (B, C, H, W)``.

        Computes ``x * std + mean`` element-wise and reshapes the result to
        the ``(C, H, W)`` given at construction. Unlike forward(), the result
        is not detached.
        """
        B = x.size(0)
        C, H, W = self.input_data_size
        x = x.reshape(B, self.data_dim)  # (B, C*H*W); tolerates non-contiguous x
        x = x * self.std + self.mean
        return x.view(B, C, H, W)