Hongyang Li
Upload 13 files
8e92669 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class SharedEncoder(nn.Module):
    """Shared convolutional feature encoder.

    Four stages of two 3x3 convolutions (stride 1, padding 1), each
    convolution followed by ReLU. A 2x2 max-pool with stride 2 follows each
    of the first three stages, so the returned feature map is spatially
    downsampled by 8 relative to the input (each pool floors odd sizes).

    Args:
        channel_list: Sequence of exactly seven ints
            ``(c1, c2, c3, c4, c5, d1, d2)``. Only ``c1``..``c4`` are used by
            this module. ``c5``, ``d1`` and ``d2`` are unpacked (so a wrong
            length raises ``ValueError``) but not consumed here. Presumably
            they configure heads defined elsewhere; confirm against callers.
        in_channels: Number of channels of the input image. Defaults to 1,
            which matches the previously hard-coded single-channel input.
    """

    def __init__(self, channel_list, in_channels=1):
        super().__init__()
        # Unpacking enforces exactly seven entries; the last three are unused here.
        c1, c2, c3, c4, _c5, _d1, _d2 = channel_list
        # Stateless modules shared across all stages.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Attribute names/order are kept stable so state_dict keys match
        # existing checkpoints.
        self.conv1a = nn.Conv2d(in_channels, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode ``x`` into deep features plus intermediate skip features.

        Args:
            x: Tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels,
                i.e. ``(N, in_channels, H, W)`` or unbatched
                ``(in_channels, H, W)``.

        Returns:
            A tuple ``(features, [conv1, conv2, conv3])``:

            * ``features`` has ``c4`` channels at 1/8 input resolution.
            * ``conv1``, ``conv2`` and ``conv3`` are the post-ReLU outputs of
              stages 1-3, taken before pooling. They have ``c1``, ``c2`` and
              ``c3`` channels at full, 1/2 and 1/4 resolution respectively.
        """
        # Stage 1
        x = self.relu(self.conv1a(x))
        conv1 = self.relu(self.conv1b(x))
        x = self.pool(conv1)
        # Stage 2
        x = self.relu(self.conv2a(x))
        conv2 = self.relu(self.conv2b(x))
        x = self.pool(conv2)
        # Stage 3
        x = self.relu(self.conv3a(x))
        conv3 = self.relu(self.conv3b(x))
        x = self.pool(conv3)
        # Stage 4 (no pooling afterwards)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        return x, [conv1, conv2, conv3]