Hongyang Li commited on
Commit
8e92669
·
verified ·
1 Parent(s): 449cb9f

Upload 13 files

Browse files
model/net/super_retina.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 这部分是解耦之后重构的SuperRetina模型
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from .super_retina_blocks.SharedEncoder import SharedEncoder
7
+ from .super_retina_blocks.DetectorHead import DetectorHead
8
+ from .super_retina_blocks.DescriptorHead import DescriptorHead
9
+
10
+
11
+ class SuperRetina(nn.Module):
12
+ def __init__(self, config=None, encoder_device='cpu', detector_device='cpu', descriptor_device='cpu', n_class=1):
13
+ super().__init__()
14
+
15
+ self.channel_list = config["channel_list"]
16
+
17
+ self.encoder = SharedEncoder(self.channel_list)
18
+ self.desc_head = DescriptorHead(self.channel_list)
19
+ self.det_head = DetectorHead(self.channel_list, n_class)
20
+ self.encoder_device = encoder_device
21
+ self.detector_device = detector_device
22
+ self.descriptor_device = descriptor_device
23
+
24
+ # 移动到对应设备上
25
+ self.encoder.to(self.encoder_device)
26
+ self.det_head.to(self.detector_device)
27
+ self.desc_head.to(self.descriptor_device)
28
+
29
+ if config is not None:
30
+ self.config = config # 传递config参数
31
+ self.nms_size = config['nms_size']
32
+ self.nms_thresh = config['nms_thresh']
33
+ self.scale = 8
34
+ # self.dice = DiceLoss()
35
+ # self.kernel = get_gaussian_kernel(kernlen=config['gaussian_kernel_size'], nsig=config['gaussian_sigma']).to(device)
36
+
37
+ def forward(self, x): # 定义模型的前向传播过程
38
+ # 特征编码
39
+ enc_out, convs = self.encoder(x.to(self.encoder_device))
40
+
41
+ # 关键点检测
42
+ semi = self.det_head(enc_out.to(self.detector_device), [item.to(self.detector_device) for item in convs])
43
+
44
+ # 描述符生成
45
+ desc = self.desc_head(enc_out.to(self.descriptor_device))
46
+
47
+ return semi, desc
model/net/super_retina_blocks/DescriptorHead.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DescriptorHead(nn.Module):
7
+ """描述符生成头"""
8
+ def __init__(self, channel_list):
9
+ super().__init__()
10
+
11
+ c1, c2, c3, c4, c5, d1, d2 = channel_list
12
+
13
+ self.relu = nn.ReLU(inplace=True)
14
+ self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 尺寸不变
15
+ # ---------------------------------------------------------------------------------------
16
+ # 原本的代码这里尺寸有点问题,96下采样过后是47不是48,需要padding=1
17
+ # self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=0) # 尺寸缩小1/2
18
+ self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=1) # 尺寸缩小1/2
19
+ # ---------------------------------------------------------------------------------------
20
+ self.convDc = torch.nn.Conv2d(d1, d2, kernel_size=1, stride=1, padding=0) # 尺寸不变
21
+ self.trans_conv = nn.ConvTranspose2d(d1, d2, 2, stride=2) # 尺寸放大2倍
22
+
23
+ def forward(self, x):
24
+ cDa = self.relu(self.convDa(x))
25
+ cDb = self.relu(self.convDb(cDa))
26
+ desc = self.convDc(cDb)
27
+
28
+ # 归一化描述符
29
+ dn = torch.norm(desc, p=2, dim=1)
30
+ desc = desc.div(torch.unsqueeze(dn, 1))
31
+
32
+ return self.trans_conv(desc)
model/net/super_retina_blocks/DescriptorHead_9_0.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DescriptorHead(nn.Module):
7
+ """描述符生成头"""
8
+ def __init__(self, channel_list):
9
+ super().__init__()
10
+
11
+ c1, c2, c3, c4, c5, d1, d2 = channel_list
12
+
13
+ self.relu = nn.ReLU(inplace=True)
14
+ self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 尺寸不变
15
+ # ---------------------------------------------------------------------------------------
16
+ # 原本的代码这里尺寸有点问题,96下采样过后是47不是48,需要padding=1
17
+ # self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=0) # 尺寸缩小1/2
18
+ self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=4, stride=2, padding=1) # 尺寸缩小1/2
19
+ # ---------------------------------------------------------------------------------------
20
+ self.convDc = torch.nn.Conv2d(d1, d2, kernel_size=1, stride=1, padding=0) # 尺寸不变
21
+ # self.trans_conv = nn.ConvTranspose2d(d1, d2, 2, stride=2) # 尺寸放大2倍
22
+
23
+ def forward(self, x):
24
+ cDa = self.relu(self.convDa(x))
25
+ cDb = self.relu(self.convDb(cDa))
26
+ desc = self.convDc(cDb)
27
+
28
+ # 归一化描述符
29
+ dn = torch.norm(desc, p=2, dim=1)
30
+ desc = desc.div(torch.unsqueeze(dn, 1))
31
+
32
+ # return self.trans_conv(desc)
33
+ return desc
model/net/super_retina_blocks/DetectorHead.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .DoubleConv import DoubleConv
6
+
7
+ class DetectorHead(nn.Module):
8
+ """关键点检测头部"""
9
+ def __init__(self, channel_list, n_class):
10
+ """
11
+ in_channels_list: 各层输入通道数 [c3+c4, c2+c3, c1+c2]
12
+ """
13
+ super().__init__()
14
+
15
+ c1, c2, c3, c4, c5, d1, d2 = channel_list
16
+
17
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
18
+ # self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)
19
+
20
+ # 上采样路径
21
+ self.dconv_up3 = DoubleConv(c3 + c4, c3)
22
+ self.dconv_up2 = DoubleConv(c2 + c3, c2)
23
+ self.dconv_up1 = DoubleConv(c1 + c2, c1)
24
+ self.conv_last = nn.Conv2d(c1, n_class, kernel_size=1)
25
+
26
+ def forward(self, x, convs):
27
+ """
28
+ x: 编码器输出特征
29
+ convs: 中间特征 [conv1, conv2, conv3]
30
+ """
31
+ # 上采样阶段1
32
+ cPa = self.upsample(x)
33
+ cPa = torch.cat([cPa, convs[2]], dim=1) # 连接conv3
34
+ cPa = self.dconv_up3(cPa)
35
+
36
+ # 上采样阶段2
37
+ cPa = self.upsample(cPa)
38
+ cPa = torch.cat([cPa, convs[1]], dim=1) # 连接conv2
39
+ cPa = self.dconv_up2(cPa)
40
+
41
+ # 上采样阶段3
42
+ cPa = self.upsample(cPa)
43
+ cPa = torch.cat([cPa, convs[0]], dim=1) # 连接conv1
44
+ cPa = self.dconv_up1(cPa)
45
+
46
+ # 最终输出
47
+ semi = self.conv_last(cPa)
48
+ return torch.sigmoid(semi)
model/net/super_retina_blocks/DoubleConv.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class DoubleConv(nn.Module):
6
+ """双卷积模块"""
7
+ def __init__(self, in_channels, out_channels):
8
+ super().__init__()
9
+ self.conv = nn.Sequential(
10
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
11
+ nn.ReLU(inplace=True),
12
+ nn.Conv2d(out_channels, out_channels, 3, padding=1),
13
+ nn.ReLU(inplace=True)
14
+ )
15
+
16
+ def forward(self, x):
17
+ return self.conv(x)
model/net/super_retina_blocks/SharedEncoder.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SharedEncoder(nn.Module):
7
+ """共享特征编码器"""
8
+ def __init__(self, channel_list):
9
+ super().__init__()
10
+
11
+ c1, c2, c3, c4, c5, d1, d2 = channel_list
12
+
13
+ self.relu = nn.ReLU(inplace=True)
14
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
15
+ self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
16
+ self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
17
+ self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
18
+ self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
19
+ self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
20
+ self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
21
+ self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
22
+ self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
23
+
24
+ def forward(self, x):
25
+ # 第一层
26
+ x = self.relu(self.conv1a(x))
27
+ conv1 = self.relu(self.conv1b(x))
28
+ x = self.pool(conv1)
29
+ # 第二层
30
+ x = self.relu(self.conv2a(x))
31
+ conv2 = self.relu(self.conv2b(x))
32
+ x = self.pool(conv2)
33
+ # 第三层
34
+ x = self.relu(self.conv3a(x))
35
+ conv3 = self.relu(self.conv3b(x))
36
+ x = self.pool(conv3)
37
+ # 第四层
38
+ x = self.relu(self.conv4a(x))
39
+ x = self.relu(self.conv4b(x))
40
+
41
+ return x, [conv1, conv2, conv3]
model/net/super_retina_blocks/__init__.py ADDED
File without changes
model/net/super_retina_blocks/__pycache__/DescriptorHead.cpython-310.pyc ADDED
Binary file (1.3 kB). View file
 
model/net/super_retina_blocks/__pycache__/DescriptorHead_9_0.cpython-310.pyc ADDED
Binary file (1.24 kB). View file
 
model/net/super_retina_blocks/__pycache__/DetectorHead.cpython-310.pyc ADDED
Binary file (1.56 kB). View file
 
model/net/super_retina_blocks/__pycache__/DoubleConv.cpython-310.pyc ADDED
Binary file (942 Bytes). View file
 
model/net/super_retina_blocks/__pycache__/SharedEncoder.cpython-310.pyc ADDED
Binary file (1.53 kB). View file
 
model/net/super_retina_blocks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file