Spaces:
Sleeping
Sleeping
File size: 1,563 Bytes
8e92669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class SharedEncoder(nn.Module):
    """Shared convolutional feature encoder.

    A stack of four stages, each made of two 3x3 convolutions (stride 1,
    padding 1) followed by ReLU. The first three stages end in a 2x2
    max-pool (stride 2), so the returned feature map is at 1/8 of the input
    spatial resolution (``MaxPool2d`` floors sizes that are not divisible
    by 2 at each step).

    Args:
        channel_list: Seven channel sizes ``(c1, c2, c3, c4, c5, d1, d2)``.
            Only ``c1``..``c4`` are used by this encoder. The remaining three
            are accepted so one list can be shared with other modules
            (presumably downstream heads; confirm against callers).
        in_channels: Number of channels in the input tensor. Defaults to 1,
            the value the original implementation hard-coded.

    Raises:
        ValueError: If ``channel_list`` does not contain exactly seven values.
    """

    def __init__(self, channel_list, *, in_channels=1):
        super().__init__()
        channels = tuple(channel_list)
        if len(channels) != 7:
            raise ValueError(
                "channel_list must contain 7 values "
                f"(c1, c2, c3, c4, c5, d1, d2), got {len(channels)}"
            )
        c1, c2, c3, c4 = channels[:4]

        def conv3x3(cin, cout):
            # Size-preserving 3x3 convolution used by every stage.
            return nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1)

        # inplace=True is safe here: each ReLU input is a fresh convolution
        # output that nothing else holds a reference to.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Attribute names must stay stable: they define the state_dict keys
        # (e.g. "conv1a.weight") used when loading saved weights.
        self.conv1a = conv3x3(in_channels, c1)
        self.conv1b = conv3x3(c1, c1)
        self.conv2a = conv3x3(c1, c2)
        self.conv2b = conv3x3(c2, c2)
        self.conv3a = conv3x3(c2, c3)
        self.conv3b = conv3x3(c3, c3)
        self.conv4a = conv3x3(c3, c4)
        self.conv4b = conv3x3(c4, c4)

    def forward(self, x):
        """Encode ``x`` and return the deepest features plus skip features.

        Args:
            x: Input accepted by ``nn.Conv2d``; presumably a batch shaped
                (N, in_channels, H, W). TODO: confirm with callers.

        Returns:
            A tuple ``(features, skips)``. ``features`` is the stage-4 output
            with ``c4`` channels at ~H/8 x W/8. ``skips`` is the list
            ``[conv1, conv2, conv3]`` of pre-pooling outputs of stages 1-3,
            at full, 1/2 and 1/4 resolution respectively.
        """
        # Stage 1 (full resolution).
        x = self.relu(self.conv1a(x))
        conv1 = self.relu(self.conv1b(x))
        x = self.pool(conv1)
        # Stage 2 (1/2 resolution).
        x = self.relu(self.conv2a(x))
        conv2 = self.relu(self.conv2b(x))
        x = self.pool(conv2)
        # Stage 3 (1/4 resolution).
        x = self.relu(self.conv3a(x))
        conv3 = self.relu(self.conv3b(x))
        x = self.pool(conv3)
        # Stage 4 (1/8 resolution); not pooled.
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        return x, [conv1, conv2, conv3]