我想在 PyTorch 中给下面这段代码添加 skip（跳跃）连接，应该怎么做？
class Branch(nn.Module):
    """A stack of conv/bn/relu stages followed by a lazy linear projector.

    For each consecutive pair in ``channels``, stage ``i`` applies:
    Conv2d(channels[i-1] -> channels[i], kernel_size=3, stride=strides[i-1],
    padding=(0, 1)) + BatchNorm2d + ReLU, then a second 'same'-padded
    Conv2d(channels[i] -> channels[i]) + BatchNorm2d + ReLU.

    Args:
        channels: sequence of channel counts; ``channels[0]`` is the input
            channel count, each later entry is a stage's output width.
        strides: per-stage strides; must have ``len(channels) - 1`` entries.
    """

    def __init__(self, channels, strides):
        super(Branch, self).__init__()
        self.conv_layer = nn.Sequential()
        # NOTE: the original statements ended with stray trailing commas,
        # silently building throwaway one-element tuples each call; removed.
        for i in range(1, len(channels)):
            self.conv_layer.add_module(
                f'conv_{i}',
                nn.Conv2d(in_channels=channels[i - 1], out_channels=channels[i],
                          kernel_size=3, stride=strides[i - 1], padding=(0, 1)))
            self.conv_layer.add_module(f'bn_{i}', nn.BatchNorm2d(channels[i]))
            self.conv_layer.add_module(f'relu_{i}', nn.ReLU())
            self.conv_layer.add_module(
                f'conv2_{i}',
                nn.Conv2d(in_channels=channels[i], out_channels=channels[i],
                          kernel_size=3, padding='same'))
            self.conv_layer.add_module(f'bn2_{i}', nn.BatchNorm2d(channels[i]))
            self.conv_layer.add_module(f'relu2_{i}', nn.ReLU())
        # Input size is inferred lazily on the first forward pass.
        self.projector = nn.LazyLinear(256)

    def forward(self, x):
        # x: assumed 4-D (batch, C_in, H, W) — TODO confirm with caller.
        x = self.conv_layer(x)
        # Fold channel and height into one feature axis and make width the
        # sequence axis: (B, C, H, W) -> (B, C*H, W) -> (B, W, C*H).
        x = x.view(x.size(0), x.size(1) * x.size(2), x.size(3)).permute(0, 2, 1)
        # Project each width position's C*H features to a 256-dim embedding.
        x = self.projector(x)
        return x
我尝试搜索,但没有找到我想要的。
如果要添加跳跃连接，需要为其创建一个新的 `nn.Module` 子类。您还必须决定跳跃连接中要包含哪些层。例如：
class SkipConv(nn.Module):
    """Conv/BatchNorm/ReLU branch with an additive skip connection.

    Computes ``relu(x + conv_branch(x))``.

    NOTE(review): the residual add assumes the branch output has the same
    shape as ``x`` — i.e. ``in_channels == out_channels`` and a stride/
    padding combination that preserves spatial size; confirm at call sites.
    """

    def __init__(self, in_channels, out_channels, stride, kernel_size, padding):
        super().__init__()
        branch = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*branch)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Residual sum followed by a final activation.
        return self.relu(x + self.conv(x))
然后，您可以将模型中相应的 conv/bn/relu 层替换为一个 `SkipConv` 块。