File size: 1,065 Bytes
199c8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn

class Preprocessor(nn.Module):
    """Element-wise standardization of image-shaped batches.

    Stores a per-element ``mean`` and ``std`` over the flattened ``C*H*W``
    feature vector as module buffers. ``forward`` computes
    ``(x - mean) / std`` and ``inverse`` computes ``x * std + mean``.
    Until :meth:`prepare` is called, the transform is the identity
    (mean 0, std 1).
    """

    def __init__(self, input_data_size):
        """
        Args:
            input_data_size: ``(C, H, W)`` shape of a single sample.
        """
        super().__init__()
        self.input_data_size = input_data_size
        C, H, W = input_data_size
        self.data_dim = C * H * W
        # zeros/ones rather than torch.empty: an unprepared module must not
        # normalize with (or save to state_dict) uninitialized memory.
        self.register_buffer("mean", torch.zeros(self.data_dim))
        self.register_buffer("std", torch.ones(self.data_dim))

    def prepare(self, mean, var, eps=0.0):
        """Set the normalization statistics.

        Args:
            mean: per-feature mean. Must broadcast against ``(B, C*H*W)``
                (typically a 1-D tensor of length ``C*H*W``).
            var: per-feature variance, same shape as ``mean``.
            eps: features whose standard deviation is ``<= eps`` get a std
                of 1 instead. Constant features then map to ``x - mean``
                rather than ``inf``/``NaN``.
        """
        self.mean = mean.to(self.mean.device)
        # Clamp tiny negative variances (floating-point error) before sqrt,
        # otherwise they would turn into NaN.
        std = torch.sqrt(var.to(self.std.device).clamp_min(0))
        self.std = torch.where(std > eps, std, torch.ones_like(std))

    def forward(self, x):
        """Normalize ``x``: ``(B, C, H, W) -> (B, C, H, W)``.

        The returned tensor is detached from the autograd graph.
        """
        B, C, H, W = x.size()
        # reshape (not view) so non-contiguous inputs are accepted
        x = x.reshape(B, self.data_dim)  # (B, C*H*W)
        x = (x - self.mean) / self.std
        return x.reshape(B, C, H, W).detach()

    def inverse(self, x):
        """Undo the normalization: ``(B, ...) -> (B, C, H, W)``.

        The output spatial shape comes from ``input_data_size``. ``x`` must
        hold ``B * C * H * W`` elements.
        """
        B = x.size(0)
        C, H, W = self.input_data_size
        x = x.reshape(B, self.data_dim)  # (B, C*H*W)
        x = x * self.std + self.mean
        return x.reshape(B, C, H, W)