import torch
import torch.nn as nn
import torch.nn.functional as F


class DescriptorHead(nn.Module):
    """Descriptor generation head."""

    def __init__(self, channel_list):
        super().__init__()
        c1, c2, c3, c4, c5, d1, d2 = channel_list
        self.relu = nn.ReLU(inplace=True)
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)  # spatial size unchanged
        # -----------------------------------------------------------------------------------------
        # The original code had a size mismatch here: downsampling a 96-pixel input with
        # kernel_size=4, stride=2, padding=0 gives 47, not 48, so padding=1 is required.
        # self.convDb = nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=0)  # spatial size halved
        self.convDb = nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=1)  # spatial size halved
        # -----------------------------------------------------------------------------------------
        self.convDc = nn.Conv2d(d1, d2, kernel_size=1, stride=1, padding=0)  # spatial size unchanged
        # self.trans_conv = nn.ConvTranspose2d(d1, d2, 2, stride=2)  # spatial size doubled

    def forward(self, x):
        cDa = self.relu(self.convDa(x))
        cDb = self.relu(self.convDb(cDa))
        desc = self.convDc(cDb)
        # L2-normalize the descriptors along the channel dimension
        dn = torch.norm(desc, p=2, dim=1)
        desc = desc.div(torch.unsqueeze(dn, 1))
        # return self.trans_conv(desc)
        return desc
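
# --- Usage sketch (illustration only) -----------------------------------------------------------
# A minimal shape check for the padding fix above. The channel_list values and the 96x96 input
# size are assumptions chosen for illustration; they are not taken from the original project.
if __name__ == "__main__":
    head = DescriptorHead(channel_list=[32, 64, 128, 128, 256, 256, 128])
    x = torch.randn(1, 128, 96, 96)  # c4=128 input channels, 96x96 feature map
    desc = head(x)
    # convDb output size: floor((96 + 2*1 - 4) / 2) + 1 = 48 with padding=1,
    # versus floor((96 - 4) / 2) + 1 = 47 with padding=0 (the original mismatch).
    print(desc.shape)  # expected: torch.Size([1, 128, 48, 48])
    # After normalization, each spatial location of the descriptor map should be ~unit length.
    print(desc.norm(p=2, dim=1).mean())  # expected: ~1.0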