Spaces:
Sleeping
Sleeping
Hongyang Li
commited on
Upload 13 files
Browse files- model/net/super_retina.py +47 -0
- model/net/super_retina_blocks/DescriptorHead.py +32 -0
- model/net/super_retina_blocks/DescriptorHead_9_0.py +33 -0
- model/net/super_retina_blocks/DetectorHead.py +48 -0
- model/net/super_retina_blocks/DoubleConv.py +17 -0
- model/net/super_retina_blocks/SharedEncoder.py +41 -0
- model/net/super_retina_blocks/__init__.py +0 -0
- model/net/super_retina_blocks/__pycache__/DescriptorHead.cpython-310.pyc +0 -0
- model/net/super_retina_blocks/__pycache__/DescriptorHead_9_0.cpython-310.pyc +0 -0
- model/net/super_retina_blocks/__pycache__/DetectorHead.cpython-310.pyc +0 -0
- model/net/super_retina_blocks/__pycache__/DoubleConv.cpython-310.pyc +0 -0
- model/net/super_retina_blocks/__pycache__/SharedEncoder.cpython-310.pyc +0 -0
- model/net/super_retina_blocks/__pycache__/__init__.cpython-310.pyc +0 -0
model/net/super_retina.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 这部分是解耦之后重构的SuperRetina模型
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
from .super_retina_blocks.SharedEncoder import SharedEncoder
|
| 7 |
+
from .super_retina_blocks.DetectorHead import DetectorHead
|
| 8 |
+
from .super_retina_blocks.DescriptorHead import DescriptorHead
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SuperRetina(nn.Module):
|
| 12 |
+
def __init__(self, config=None, encoder_device='cpu', detector_device='cpu', descriptor_device='cpu', n_class=1):
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
self.channel_list = config["channel_list"]
|
| 16 |
+
|
| 17 |
+
self.encoder = SharedEncoder(self.channel_list)
|
| 18 |
+
self.desc_head = DescriptorHead(self.channel_list)
|
| 19 |
+
self.det_head = DetectorHead(self.channel_list, n_class)
|
| 20 |
+
self.encoder_device = encoder_device
|
| 21 |
+
self.detector_device = detector_device
|
| 22 |
+
self.descriptor_device = descriptor_device
|
| 23 |
+
|
| 24 |
+
# 移动到对应设备上
|
| 25 |
+
self.encoder.to(self.encoder_device)
|
| 26 |
+
self.det_head.to(self.detector_device)
|
| 27 |
+
self.desc_head.to(self.descriptor_device)
|
| 28 |
+
|
| 29 |
+
if config is not None:
|
| 30 |
+
self.config = config # 传递config参数
|
| 31 |
+
self.nms_size = config['nms_size']
|
| 32 |
+
self.nms_thresh = config['nms_thresh']
|
| 33 |
+
self.scale = 8
|
| 34 |
+
# self.dice = DiceLoss()
|
| 35 |
+
# self.kernel = get_gaussian_kernel(kernlen=config['gaussian_kernel_size'], nsig=config['gaussian_sigma']).to(device)
|
| 36 |
+
|
| 37 |
+
def forward(self, x): # 定义模型的前向传播过程
|
| 38 |
+
# 特征编码
|
| 39 |
+
enc_out, convs = self.encoder(x.to(self.encoder_device))
|
| 40 |
+
|
| 41 |
+
# 关键点检测
|
| 42 |
+
semi = self.det_head(enc_out.to(self.detector_device), [item.to(self.detector_device) for item in convs])
|
| 43 |
+
|
| 44 |
+
# 描述符生成
|
| 45 |
+
desc = self.desc_head(enc_out.to(self.descriptor_device))
|
| 46 |
+
|
| 47 |
+
return semi, desc
|
model/net/super_retina_blocks/DescriptorHead.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DescriptorHead(nn.Module):
|
| 7 |
+
"""描述符生成头"""
|
| 8 |
+
def __init__(self, channel_list):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
c1, c2, c3, c4, c5, d1, d2 = channel_list
|
| 12 |
+
|
| 13 |
+
self.relu = nn.ReLU(inplace=True)
|
| 14 |
+
self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 尺寸不变
|
| 15 |
+
# ---------------------------------------------------------------------------------------
|
| 16 |
+
# 原本的代码这里尺寸有点问题,96下采样过后是47不是48,需要padding=1
|
| 17 |
+
# self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=0) # 尺寸缩小1/2
|
| 18 |
+
self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=1) # 尺寸缩小1/2
|
| 19 |
+
# ---------------------------------------------------------------------------------------
|
| 20 |
+
self.convDc = torch.nn.Conv2d(d1, d2, kernel_size=1, stride=1, padding=0) # 尺寸不变
|
| 21 |
+
self.trans_conv = nn.ConvTranspose2d(d1, d2, 2, stride=2) # 尺寸放大2倍
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
cDa = self.relu(self.convDa(x))
|
| 25 |
+
cDb = self.relu(self.convDb(cDa))
|
| 26 |
+
desc = self.convDc(cDb)
|
| 27 |
+
|
| 28 |
+
# 归一化描述符
|
| 29 |
+
dn = torch.norm(desc, p=2, dim=1)
|
| 30 |
+
desc = desc.div(torch.unsqueeze(dn, 1))
|
| 31 |
+
|
| 32 |
+
return self.trans_conv(desc)
|
model/net/super_retina_blocks/DescriptorHead_9_0.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DescriptorHead(nn.Module):
|
| 7 |
+
"""描述符生成头"""
|
| 8 |
+
def __init__(self, channel_list):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
c1, c2, c3, c4, c5, d1, d2 = channel_list
|
| 12 |
+
|
| 13 |
+
self.relu = nn.ReLU(inplace=True)
|
| 14 |
+
self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 尺寸不变
|
| 15 |
+
# ---------------------------------------------------------------------------------------
|
| 16 |
+
# 原本的代码这里尺寸有点问题,96下采样过后是47不是48,需要padding=1
|
| 17 |
+
# self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=0) # 尺寸缩小1/2
|
| 18 |
+
self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=1) # 尺寸缩小1/2
|
| 19 |
+
# ---------------------------------------------------------------------------------------
|
| 20 |
+
self.convDc = torch.nn.Conv2d(d1, d2, kernel_size=1, stride=1, padding=0) # 尺寸不变
|
| 21 |
+
# self.trans_conv = nn.ConvTranspose2d(d1, d2, 2, stride=2) # 尺寸放大2倍
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
cDa = self.relu(self.convDa(x))
|
| 25 |
+
cDb = self.relu(self.convDb(cDa))
|
| 26 |
+
desc = self.convDc(cDb)
|
| 27 |
+
|
| 28 |
+
# 归一化描述符
|
| 29 |
+
dn = torch.norm(desc, p=2, dim=1)
|
| 30 |
+
desc = desc.div(torch.unsqueeze(dn, 1))
|
| 31 |
+
|
| 32 |
+
# return self.trans_conv(desc)
|
| 33 |
+
return desc
|
model/net/super_retina_blocks/DetectorHead.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from .DoubleConv import DoubleConv
|
| 6 |
+
|
| 7 |
+
class DetectorHead(nn.Module):
|
| 8 |
+
"""关键点检测头部"""
|
| 9 |
+
def __init__(self, channel_list, n_class):
|
| 10 |
+
"""
|
| 11 |
+
in_channels_list: 各层输入通道数 [c3+c4, c2+c3, c1+c2]
|
| 12 |
+
"""
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
c1, c2, c3, c4, c5, d1, d2 = channel_list
|
| 16 |
+
|
| 17 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 18 |
+
# self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)
|
| 19 |
+
|
| 20 |
+
# 上采样路径
|
| 21 |
+
self.dconv_up3 = DoubleConv(c3 + c4, c3)
|
| 22 |
+
self.dconv_up2 = DoubleConv(c2 + c3, c2)
|
| 23 |
+
self.dconv_up1 = DoubleConv(c1 + c2, c1)
|
| 24 |
+
self.conv_last = nn.Conv2d(c1, n_class, kernel_size=1)
|
| 25 |
+
|
| 26 |
+
def forward(self, x, convs):
|
| 27 |
+
"""
|
| 28 |
+
x: 编码器输出特征
|
| 29 |
+
convs: 中间特征 [conv1, conv2, conv3]
|
| 30 |
+
"""
|
| 31 |
+
# 上采样阶段1
|
| 32 |
+
cPa = self.upsample(x)
|
| 33 |
+
cPa = torch.cat([cPa, convs[2]], dim=1) # 连接conv3
|
| 34 |
+
cPa = self.dconv_up3(cPa)
|
| 35 |
+
|
| 36 |
+
# 上采样阶段2
|
| 37 |
+
cPa = self.upsample(cPa)
|
| 38 |
+
cPa = torch.cat([cPa, convs[1]], dim=1) # 连接conv2
|
| 39 |
+
cPa = self.dconv_up2(cPa)
|
| 40 |
+
|
| 41 |
+
# 上采样阶段3
|
| 42 |
+
cPa = self.upsample(cPa)
|
| 43 |
+
cPa = torch.cat([cPa, convs[0]], dim=1) # 连接conv1
|
| 44 |
+
cPa = self.dconv_up1(cPa)
|
| 45 |
+
|
| 46 |
+
# 最终输出
|
| 47 |
+
semi = self.conv_last(cPa)
|
| 48 |
+
return torch.sigmoid(semi)
|
model/net/super_retina_blocks/DoubleConv.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class DoubleConv(nn.Module):
|
| 6 |
+
"""双卷积模块"""
|
| 7 |
+
def __init__(self, in_channels, out_channels):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.conv = nn.Sequential(
|
| 10 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
| 11 |
+
nn.ReLU(inplace=True),
|
| 12 |
+
nn.Conv2d(out_channels, out_channels, 3, padding=1),
|
| 13 |
+
nn.ReLU(inplace=True)
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
return self.conv(x)
|
model/net/super_retina_blocks/SharedEncoder.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SharedEncoder(nn.Module):
|
| 7 |
+
"""共享特征编码器"""
|
| 8 |
+
def __init__(self, channel_list):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
c1, c2, c3, c4, c5, d1, d2 = channel_list
|
| 12 |
+
|
| 13 |
+
self.relu = nn.ReLU(inplace=True)
|
| 14 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 15 |
+
self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
|
| 16 |
+
self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
|
| 17 |
+
self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
|
| 18 |
+
self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
|
| 19 |
+
self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
|
| 20 |
+
self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
|
| 21 |
+
self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
|
| 22 |
+
self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
# 第一层
|
| 26 |
+
x = self.relu(self.conv1a(x))
|
| 27 |
+
conv1 = self.relu(self.conv1b(x))
|
| 28 |
+
x = self.pool(conv1)
|
| 29 |
+
# 第二层
|
| 30 |
+
x = self.relu(self.conv2a(x))
|
| 31 |
+
conv2 = self.relu(self.conv2b(x))
|
| 32 |
+
x = self.pool(conv2)
|
| 33 |
+
# 第三层
|
| 34 |
+
x = self.relu(self.conv3a(x))
|
| 35 |
+
conv3 = self.relu(self.conv3b(x))
|
| 36 |
+
x = self.pool(conv3)
|
| 37 |
+
# 第四层
|
| 38 |
+
x = self.relu(self.conv4a(x))
|
| 39 |
+
x = self.relu(self.conv4b(x))
|
| 40 |
+
|
| 41 |
+
return x, [conv1, conv2, conv3]
|
model/net/super_retina_blocks/__init__.py
ADDED
|
File without changes
|
model/net/super_retina_blocks/__pycache__/DescriptorHead.cpython-310.pyc
ADDED
|
Binary file (1.3 kB). View file
|
|
|
model/net/super_retina_blocks/__pycache__/DescriptorHead_9_0.cpython-310.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
model/net/super_retina_blocks/__pycache__/DetectorHead.cpython-310.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
model/net/super_retina_blocks/__pycache__/DoubleConv.cpython-310.pyc
ADDED
|
Binary file (942 Bytes). View file
|
|
|
model/net/super_retina_blocks/__pycache__/SharedEncoder.cpython-310.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
model/net/super_retina_blocks/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (163 Bytes). View file
|
|
|