Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .DoubleConv import DoubleConv | |
class DetectorHead(nn.Module):
    """Keypoint detection head: a U-Net-style decoder plus a 1x1 classifier.

    Three upsample -> concatenate-skip -> DoubleConv stages fuse the encoder
    output with the skip features ``convs[2]``, ``convs[1]`` and ``convs[0]``
    in turn. A final 1x1 convolution then produces ``n_class`` maps, which are
    squashed to (0, 1) with a sigmoid.
    """

    def __init__(self, channel_list, n_class):
        """
        Args:
            channel_list: Exactly seven channel counts
                ``(c1, c2, c3, c4, c5, d1, d2)``. Only ``c1``..``c4`` are used
                by this head. ``c5``, ``d1`` and ``d2`` are accepted so that one
                shared config list can be passed around (presumably also to the
                encoder / other heads -- confirm against callers). The decoder
                stages take ``c3 + c4``, ``c2 + c3`` and ``c1 + c2`` input
                channels respectively.
            n_class: Number of output maps produced by the final 1x1 conv.

        Raises:
            ValueError: if ``channel_list`` does not have exactly 7 entries.
        """
        super().__init__()
        # Strict 7-way unpacking preserves the original contract: a malformed
        # channel_list fails loudly here instead of being silently accepted.
        c1, c2, c3, c4, _c5, _d1, _d2 = channel_list
        # forward() reads this module's mode/align_corners but resizes to each
        # skip tensor's exact spatial size rather than a fixed 2x factor.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Decoder (upsampling) path, deepest stage first.
        self.dconv_up3 = DoubleConv(c3 + c4, c3)
        self.dconv_up2 = DoubleConv(c2 + c3, c2)
        self.dconv_up1 = DoubleConv(c1 + c2, c1)
        self.conv_last = nn.Conv2d(c1, n_class, kernel_size=1)

    def _up_and_concat(self, x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate along dim 1."""
        # Targeting skip's (H, W) directly avoids a torch.cat size mismatch when
        # an encoder stage had an odd size. When skip is exactly 2x x's size,
        # this equals the fixed scale_factor=2 upsampling with align_corners=True.
        up = F.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )
        return torch.cat([up, skip], dim=1)

    def forward(self, x, convs):
        """
        Args:
            x: Encoder output feature map (4-D, as required by bilinear
                interpolation). Its channel count plus ``convs[2]``'s must equal
                ``c3 + c4``.
            convs: Skip features ``[conv1, conv2, conv3]``, ordered shallow to
                deep. ``convs[i]`` is concatenated after the stage that
                produces ``c(i+2)`` channels.

        Returns:
            Sigmoid activations of shape ``(N, n_class, H, W)``, where ``(H, W)``
            is the spatial size of ``convs[0]``.
        """
        feat = self.dconv_up3(self._up_and_concat(x, convs[2]))     # fuse conv3
        feat = self.dconv_up2(self._up_and_concat(feat, convs[1]))  # fuse conv2
        feat = self.dconv_up1(self._up_and_concat(feat, convs[0]))  # fuse conv1
        semi = self.conv_last(feat)
        return torch.sigmoid(semi)