import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from MT import FeatureTransformer
from torch.cuda.amp import autocast
from flow_tools import viz_img_seq, save_img_seq, plt_show_img_flow
from copy import deepcopy
from V1 import V1
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image


def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
    if isReLU:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      dilation=dilation,
                      padding=((kernel_size - 1) * dilation) // 2, bias=True),
            nn.GELU()
        )
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      dilation=dilation,
                      padding=((kernel_size - 1) * dilation) // 2, bias=True)
        )

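
# A minimal sanity sketch (not part of the original file): with stride=1 the padding
# rule ((kernel_size - 1) * dilation) // 2 keeps the spatial size unchanged for odd
# kernel sizes, which the decoders below rely on. Shapes here are illustrative only.
def _conv_shape_sketch():
    x = torch.randn(1, 8, 32, 32)
    layer = conv(8, 16, kernel_size=3, stride=1, dilation=2)
    assert layer(x).shape == (1, 16, 32, 32)  # spatial size preserved

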
def plt_attention(attention, h, w):
    col = len(attention) // 2
    fig = plt.figure(figsize=(10, 8))
    for i in range(len(attention)):
        viz = attention[i][0, :, :, h, w].detach().cpu().numpy()
        # viz = viz[7:-7, 7:-7]
        if i == 0:
            viz_all = viz
        else:
            viz_all = viz_all + viz

        ax1 = fig.add_subplot(2, col, i + 1)
        img = ax1.imshow(viz, cmap="rainbow", interpolation="bilinear")
        ax1.scatter(w, h, color='grey', s=300, alpha=0.5)
        ax1.scatter(w, h, color='red', s=150, alpha=0.5)
        plt.title("Iteration %d" % (i + 1))
        if i == len(attention) - 1:
            plt.title("Final Iteration")
        plt.xticks([])
        plt.yticks([])

    # tight layout
    plt.tight_layout()
    # save the figure to an in-memory buffer
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()
    # convert the figure to an array
    img = Image.open(buf)
    img = np.array(img)
    return img

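
# A hedged usage sketch (not part of the original file): plt_attention expects a list
# of per-iteration attention tensors indexable as [0, :, :, h, w] (roughly
# [B, H, W, H, W]) and returns the rendered matplotlib figure as a NumPy image array.
# The dummy shapes below are assumptions for illustration only.
def _plt_attention_sketch():
    dummy_attn = [torch.rand(1, 16, 16, 16, 16) for _ in range(4)]
    heatmap = plt_attention(dummy_attn, h=8, w=8)
    print(heatmap.shape)  # e.g. (800, 1000, 4), depending on the matplotlib backend/dpi

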
class FlowDecoder(nn.Module):
    # can reduce 25% of training time.
    def __init__(self, ch_in):
        super(FlowDecoder, self).__init__()
        self.conv1 = conv(ch_in, 256, kernel_size=1)
        self.conv2 = conv(256, 128, kernel_size=1)
        self.conv3 = conv(256 + 128, 96, kernel_size=1)
        self.conv4 = conv(96 + 128, 64, kernel_size=1)
        self.conv5 = conv(96 + 64, 32, kernel_size=1)
        self.feat_dim = 32
        self.predict_flow = conv(64 + 32, 2, isReLU=False)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(torch.cat([x1, x2], dim=1))
        x4 = self.conv4(torch.cat([x2, x3], dim=1))
        x5 = self.conv5(torch.cat([x3, x4], dim=1))
        flow = self.predict_flow(torch.cat([x4, x5], dim=1))
        return flow

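
# A hedged shape sketch (not part of the original file): FlowDecoder stacks mostly 1x1
# convolutions with dense skip connections, mapping a [B, ch_in, h, w] motion-feature
# map to a 2-channel flow field at the same (1/8) resolution. Values are illustrative.
def _flow_decoder_shape_sketch():
    decoder = FlowDecoder(ch_in=256)
    feat = torch.randn(2, 256, 32, 32)  # e.g. a 256x256 clip downsampled by 8
    assert decoder(feat).shape == (2, 2, 32, 32)

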
class FFV1DNN(nn.Module):
    def __init__(self,
                 num_scales=8,
                 num_cells=256,
                 upsample_factor=8,
                 feature_channels=256,
                 scale_factor=16,
                 num_layers=6,
                 ):
        super(FFV1DNN, self).__init__()
        self.ffv1 = V1(spatial_num=num_cells // num_scales, scale_num=num_scales, scale_factor=scale_factor,
                       kernel_radius=7, num_ft=num_cells // num_scales,
                       kernel_size=6, average_time=True)
        self.v1_kz = 7
        self.scale_factor = scale_factor
        scale_each_level = np.exp(1 / (num_scales - 1) * np.log(1 / scale_factor))
        self.scale_num = num_scales
        self.scale_each_level = scale_each_level
        v1_channel = self.ffv1.num_after_st
        self.num_scales = num_scales
        self.MT_channel = feature_channels
        assert self.MT_channel == v1_channel
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.num_layers = num_layers
        # convex upsampling: concat feature0 and flow as input
        self.upsampler_1 = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, upsample_factor ** 2 * 9, 3, 1, 1))
        self.decoder = FlowDecoder(feature_channels)
        self.conv_feat = nn.ModuleList([conv(v1_channel, feature_channels, 1) for _ in range(num_scales)])
        self.MT = FeatureTransformer(d_model=feature_channels, num_layers=self.num_layers)
        # 2*2*8*scale

    def upsample_flow(self, flow, feature, upsampler=None, bilinear=False, upsample_factor=4):
        if bilinear:
            up_flow = F.interpolate(flow, scale_factor=upsample_factor,
                                    mode='bilinear', align_corners=True) * upsample_factor
        else:
            # convex upsampling
            concat = torch.cat((flow, feature), dim=1)
            mask = upsampler(concat)
            b, flow_channel, h, w = flow.shape
            mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
            mask = torch.softmax(mask, dim=2)
            up_flow = F.unfold(upsample_factor * flow, [3, 3], padding=1)
            up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]
            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, H, K, W, K]
            up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h,
                                      upsample_factor * w)  # [B, 2, K*H, K*W]
        return up_flow

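    # A hedged stand-alone sketch (not part of the original file) of the convex
    # upsampling above: each coarse pixel gets a softmax over its 3x3 neighbourhood
    # for every KxK sub-pixel, and the upsampled flow is that convex combination of
    # the (magnitude-scaled) coarse flow. Shapes mirror upsample_flow.
    @staticmethod
    def _convex_upsample_sketch():
        b, h, w, k = 1, 4, 4, 8
        flow = torch.randn(b, 2, h, w)
        weights = torch.softmax(torch.randn(b, 1, 9, k, k, h, w), dim=2)       # convex weights
        patches = F.unfold(k * flow, [3, 3], padding=1).view(b, 2, 9, 1, 1, h, w)
        up = torch.sum(weights * patches, dim=2)                               # [B, 2, K, K, H, W]
        up = up.permute(0, 1, 4, 2, 5, 3).reshape(b, 2, k * h, k * w)          # [B, 2, K*H, K*W]
        return up.shape                                                        # torch.Size([1, 2, 32, 32])
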
    def forward(self, image_list, mix_enable=True, layer=6):
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        results_dict = {}
        padding = self.v1_kz * self.scale_factor
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W], range 0-1
            if image_list[0].shape[1] == 3:
                # convert to grayscale: Gray = R*0.299 + G*0.587 + B*0.114
                image_list = [img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587 + img[:, 2, :, :] * 0.114 for img in
                              image_list]
                image_list = [img.unsqueeze(1) for img in image_list]
        B, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        with autocast(enabled=mix_enable):
            # with torch.no_grad():  # TODO: only for testing whether a trainable V1 is needed.
            st_component = self.ffv1(image_list)
            # viz_img_seq(image_scale, if_debug=True)
            if self.num_layers == 0:
                motion_feature = [st_component]
                flows = [self.decoder(feature) for feature in motion_feature]
                flows_up = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
                results_dict["flow_seq"] = flows_up
                return results_dict
            motion_feature, attn = self.MT.forward_save_mem(st_component)
            flow_v1 = self.decoder(st_component)
            flows = [flow_v1] + [self.decoder(feature) for feature in motion_feature]
            flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
            flows_up = [flows_bi[0]] + \
                       [self.upsample_flow(flow, upsampler=self.upsampler_1, feature=attn_i, upsample_factor=8) for
                        flow, attn_i in zip(flows[1:], attn)]
            assert len(flows_bi) == len(flows_up)
            results_dict["flow_seq"] = flows_up
            results_dict["flow_seq_bi"] = flows_bi
        return results_dict

    def forward_test(self, image_list, mix_enable=True, layer=6):
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        results_dict = {}
        padding = self.v1_kz * self.scale_factor
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W], range 0-1
        B, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        with autocast(enabled=mix_enable):
            st_component = self.ffv1(image_list)
            # viz_img_seq(image_scale, if_debug=True)
            if self.num_layers == 0:
                motion_feature = [st_component]
                flows = [self.decoder(feature) for feature in motion_feature]
                flows_up = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
                results_dict["flow_seq"] = flows_up
                return results_dict
            motion_feature, attn, _ = self.MT.forward_save_mem(st_component)
            flow_v1 = self.decoder(st_component)
            flows = [flow_v1] + [self.decoder(feature) for feature in motion_feature]
            flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
            flows_up = [flows_bi[0]] + \
                       [self.upsample_flow(flow, upsampler=self.upsampler_1, feature=attn_i, upsample_factor=8) for
                        flow, attn_i in zip(flows[1:], attn)]
            assert len(flows_bi) == len(flows_up)
            results_dict["flow_seq"] = flows_up
            results_dict["flow_seq_bi"] = flows_bi
        return results_dict

    def forward_viz(self, image_list, layer=None, x=50, y=50):
        # x, y are given in percent of the MT feature-map size
        x = x / 100
        y = y / 100
        if layer is not None:
            self.MT.num_layers = layer
        results_dict = {}
        padding = self.v1_kz * self.scale_factor
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W], range 0-1
            if image_list[0].shape[1] == 3:
                # convert to grayscale: Gray = R*0.299 + G*0.587 + B*0.114
                image_list = [img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587 + img[:, 2, :, :] * 0.114 for img in
                              image_list]
                image_list = [img.unsqueeze(1) for img in image_list]
        image_list_ori = deepcopy(image_list)
        B, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        with autocast(enabled=True):
            st_component = self.ffv1(image_list)
            activation = self.ffv1.visualize_activation(st_component)
            # viz_img_seq(image_scale, if_debug=True)
            motion_feature, attn, attn_viz = self.MT(st_component)
            flow_v1 = self.decoder(st_component)
            flows = [flow_v1] + [self.decoder(feature) for feature in motion_feature]
            flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
            flows_up = [flows_bi[0]] + \
                       [self.upsample_flow(flow, upsampler=self.upsampler_1, feature=attn_i, upsample_factor=8) for
                        flow, attn_i in zip(flows[1:], attn)]
            assert len(flows_bi) == len(flows_up)
            results_dict["flow_seq"] = flows_up
        flows_up = flows_up[:-1]
        print(len(flows_up), len(attn_viz))
        flow = plt_show_img_flow(image_list_ori, flows_up)
        h = int(MT_size[0] * y)
        w = int(MT_size[1] * x)
        attention = plt_attention(attn_viz, h=h, w=w)
        print("done")
        results_dict["activation"] = activation
        results_dict["attention"] = attention
        results_dict["flow"] = flow
        plt.clf()
        plt.cla()
        plt.close()
        return results_dict

    def num_parameters(self):
        return sum(
            [p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])

    def init_weights(self):
        # iterate over modules (not (name, module) pairs) so the isinstance checks match
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    @staticmethod
    def demo(file=None):
        import time
        from utils import torch_utils as utils
        frame_list = [torch.randn([4, 1, 512, 512], device="cuda")] * 11
        model = FFV1DNN(num_scales=8, scale_factor=16, num_cells=256, upsample_factor=8, num_layers=6,
                        feature_channels=256).cuda()
        if file is not None:
            model = utils.restore_model(model, file)
        print(model.num_parameters())
        for _ in range(100):
            start = time.time()
            output = model.forward_viz(frame_list, layer=7)
            # print(output["flow_seq"][-1])
            torch.mean(output["flow_seq"][-1]).backward()
            print(torch.any(torch.isnan(output["flow_seq"][-1])))
            end = time.time()
            print(end - start)
            print("#================================++#")

if __name__ == '__main__':
    FFV1DNN.demo(None)