# Hongyang Li
# Upload 13 files
# 8e92669 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from .DoubleConv import DoubleConv
class DetectorHead(nn.Module):
    """Keypoint detection head.

    A three-stage decoder: at every stage the running feature map is
    upsampled, concatenated along the channel axis (dim=1) with one
    intermediate encoder feature, and passed through a ``DoubleConv``.
    A final 1x1 convolution followed by a sigmoid produces ``n_class``
    output channels with values in [0, 1].
    """

    def __init__(self, channel_list, n_class):
        """Build the decoder layers.

        Args:
            channel_list: Iterable of exactly seven channel counts, unpacked
                as ``c1, c2, c3, c4, c5, d1, d2``. Only ``c1``..``c4`` are
                used by this head; the remaining three are required to be
                present but are ignored here.
            n_class: Number of output channels of the final 1x1 convolution.

        The three ``DoubleConv`` blocks are built with the argument pairs
        ``(c3 + c4, c3)``, ``(c2 + c3, c2)`` and ``(c1 + c2, c1)``
        (presumably ``(in_channels, out_channels)`` - confirm against
        ``DoubleConv``).
        """
        super().__init__()
        # Unpacking enforces the seven-entry length; the underscore-prefixed
        # names are intentionally unused in this module.
        c1, c2, c3, c4, _c5, _d1, _d2 = channel_list

        # Shared, parameter-free 2x upsampler used by every decoder stage.
        # (mode='bicubic' was left as a commented-out alternative.)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Upsampling path.
        self.dconv_up3 = DoubleConv(c3 + c4, c3)
        self.dconv_up2 = DoubleConv(c2 + c3, c2)
        self.dconv_up1 = DoubleConv(c1 + c2, c1)

        self.conv_last = nn.Conv2d(c1, n_class, kernel_size=1)

    def _upsample_like(self, x, skip):
        """Upsample ``x`` so its last two dimensions match those of ``skip``."""
        target = tuple(skip.shape[-2:])
        doubled = (2 * x.shape[-2], 2 * x.shape[-1])
        if doubled == target:
            # Regular case: identical to the original fixed 2x upsampling.
            return self.upsample(x)
        # The skip feature is not exactly twice as large (e.g. an odd-sized
        # encoder feature). Resize straight to the skip's size so that the
        # following torch.cat does not fail on a spatial mismatch.
        return F.interpolate(
            x,
            size=target,
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x, convs):
        """Decode encoder features into a per-class sigmoid map.

        Args:
            x: Encoder output features (4-D, channels on dim 1, as required
                by the bilinear upsampling and the channel-wise concat).
            convs: Intermediate encoder features indexed as
                ``[conv1, conv2, conv3]``; ``convs[2]`` is merged first and
                ``convs[0]`` last.

        Returns:
            Tensor with ``n_class`` channels and the spatial size of
            ``convs[0]``, squashed to [0, 1] by a sigmoid.
        """
        # Stage 1: upsample and merge with conv3.
        feat = self._upsample_like(x, convs[2])
        feat = torch.cat([feat, convs[2]], dim=1)
        feat = self.dconv_up3(feat)

        # Stage 2: upsample and merge with conv2.
        feat = self._upsample_like(feat, convs[1])
        feat = torch.cat([feat, convs[1]], dim=1)
        feat = self.dconv_up2(feat)

        # Stage 3: upsample and merge with conv1.
        feat = self._upsample_like(feat, convs[0])
        feat = torch.cat([feat, convs[0]], dim=1)
        feat = self.dconv_up1(feat)

        # Final 1x1 projection followed by sigmoid.
        semi = self.conv_last(feat)
        return torch.sigmoid(semi)