ReVQ / revq /models /preprocessor.py
AndyRaoTHU's picture
update
199c8cd
import torch
import torch.nn as nn
class Preprocessor(nn.Module):
    """Per-element standardization of (B, C, H, W) tensors.

    Holds ``mean`` and ``std`` buffers with one entry per element of a
    flattened ``C * H * W`` sample. ``forward`` maps ``x`` to
    ``(x - mean) / std`` and ``inverse`` maps it back with
    ``x * std + mean``.

    Args:
        input_data_size: ``(C, H, W)`` shape of a single sample.
    """

    def __init__(self, input_data_size):
        super(Preprocessor, self).__init__()
        self.input_data_size = input_data_size
        C, H, W = input_data_size
        self.data_dim = C * H * W
        # Start as the identity transform (mean 0, std 1) instead of
        # uninitialized memory, so using the module before prepare() is
        # deterministic rather than producing garbage values.
        self.register_buffer("mean", torch.zeros(self.data_dim))
        self.register_buffer("std", torch.ones(self.data_dim))

    def _flatten_stat(self, stat):
        """Flatten ``stat`` to ``(C*H*W,)`` when it holds one value per element.

        Tensors with any other number of elements (e.g. a scalar) are
        returned unchanged and rely on broadcasting in forward/inverse.
        """
        if stat.numel() == self.data_dim:
            return stat.reshape(self.data_dim)
        return stat

    def prepare(self, mean, var):
        """Set the normalization statistics.

        Args:
            mean: per-element mean, either flat ``(C*H*W,)`` or shaped like
                a sample ``(C, H, W)``; moved to the buffer's device.
            var: per-element variance with the same layout; the stored
                ``std`` is ``sqrt(var)``.

        Note:
            A zero variance gives ``std == 0`` and therefore inf/NaN in
            ``forward``; callers may want to add a small epsilon to ``var``.
        """
        self.mean = self._flatten_stat(mean.to(self.mean.device))
        self.std = torch.sqrt(self._flatten_stat(var.to(self.std.device)))

    def forward(self, x):
        """Normalize: (B, C, H, W) -> (B, C, H, W), detached from autograd."""
        B, C, H, W = x.size()
        # reshape (not view) so non-contiguous inputs such as transposed
        # tensors are accepted; it is still a view when strides allow.
        flat = x.reshape(B, self.data_dim)  # (B, C*H*W)
        flat = (flat - self.mean) / self.std
        return flat.reshape(B, C, H, W).detach()

    def inverse(self, x):
        """Un-normalize: (B, C, H, W) -> (B, C, H, W)."""
        B = x.size(0)
        C, H, W = self.input_data_size
        flat = x.reshape(B, self.data_dim)  # (B, C*H*W)
        flat = flat * self.std + self.mean
        return flat.reshape(B, C, H, W)