# Author: Hongyang Li
# Source commit: 8e92669 ("Upload 13 files")
import torch
import torch.nn as nn
import torch.nn.functional as F
class DescriptorHead(nn.Module):
    """Descriptor generation head.

    Pipeline::

        x --3x3 conv (c4 -> c5, same size)--> ReLU
          --4x4 stride-2 conv (c5 -> d1, H x W -> floor(H/2) x floor(W/2))--> ReLU
          --1x1 conv (d1 -> d2)--> L2 normalisation over the channel axis

    Args:
        channel_list: Sequence of exactly seven ints
            ``(c1, c2, c3, c4, c5, d1, d2)``. Only ``c4`` (input channels),
            ``c5``, ``d1`` and ``d2`` (descriptor dimension) are used here;
            ``c1``-``c3`` are unpacked but unused by this head.
        eps: Lower bound on the norm used during L2 normalisation, so that
            all-zero descriptors stay zero instead of turning into NaN.
    """

    def __init__(self, channel_list, eps=1e-12):
        super().__init__()
        # Strict 7-way unpack keeps the original length check on channel_list.
        _c1, _c2, _c3, c4, c5, d1, d2 = channel_list
        self.eps = eps
        # NOTE: attribute names are kept unchanged because they define the
        # state_dict keys used by saved checkpoints.
        self.relu = nn.ReLU(inplace=True)
        # 3x3, stride 1, padding 1: spatial size unchanged.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        # 4x4, stride 2, padding 1: out = (H + 2 - 4) // 2 + 1 = H // 2.
        # The original padding=0 mapped 96 -> (96 - 4) // 2 + 1 = 47 instead
        # of the intended 48, hence padding=1.
        self.convDb = nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=1)
        # 1x1: projects d1 channels to the d2-dimensional descriptor; size unchanged.
        self.convDc = nn.Conv2d(d1, d2, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Compute L2-normalised dense descriptors.

        Args:
            x: Feature map with ``c4`` channels, shaped ``(N, c4, H, W)`` or
                unbatched ``(c4, H, W)``.

        Returns:
            Tensor shaped ``(N, d2, H // 2, W // 2)`` (or
            ``(d2, H // 2, W // 2)`` for unbatched input). Each channel
            vector is divided by ``max(||v||_2, eps)``, i.e. it has unit
            length unless its norm is below ``eps``.
        """
        cDa = self.relu(self.convDa(x))
        cDb = self.relu(self.convDb(cDa))
        desc = self.convDc(cDb)
        # dim=-3 is the channel axis for both batched (N, C, H, W), where it
        # equals dim=1, and unbatched (C, H, W) input. F.normalize clamps the
        # norm at eps, so a zero vector yields zeros rather than 0/0 = NaN.
        return F.normalize(desc, p=2, dim=-3, eps=self.eps)