如何在代码的这一部分中添加跳过连接而不创建用于跳过连接的新类?

问题描述 投票:0回答:1

我想在pytorch中添加skip连接到这段代码,怎么做?

class Branch(nn.Module):
    def __init__(self, channels, strides):
        super(Branch, self).__init__()
        
        self.conv_layer = nn.Sequential()
        for i in range(1, len(channels)):
            self.conv_layer.add_module(f'conv_{i}', nn.Conv2d(in_channels=channels[i-1], out_channels=channels[i], kernel_size=3, stride=strides[i-1], padding=(0, 1))),
            self.conv_layer.add_module(f'bn_{i}', nn.BatchNorm2d(channels[i])),
            self.conv_layer.add_module(f'relu_{i}', nn.ReLU())
            self.conv_layer.add_module(f'conv2_{i}', nn.Conv2d(in_channels=channels[i], out_channels=channels[i], kernel_size=3, padding='same')),
            self.conv_layer.add_module(f'bn2_{i}', nn.BatchNorm2d(channels[i])),
            self.conv_layer.add_module(f'relu2_{i}', nn.ReLU())

        self.projector = nn.LazyLinear(256)
        
    def forward(self, x):
        x = self.conv_layer(x)
        x = x.view(x.size(0), x.size(1) * x.size(2), x.size(3)).permute(0, 2, 1)
        x = self.projector(x)
        return x

我尝试搜索,但没有找到我想要的。

python pytorch conv-neural-network
1个回答
0
投票

如果要添加跳过连接,则需要为其创建一个新的

nn.Module
类。您还必须决定在跳过连接中需要哪些层。例如:

class SkipConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride, kernel_size, padding):
        super().__init__()
        
        self.conv = nn.Sequential(
                        nn.Conv2d(in_channels, 
                                  out_channels, 
                                  kernel_size=kernel_size, 
                                  stride=stride, 
                                  padding=padding
                                 ),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU())
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = x + self.conv(x)
        x = self.relu(x)
        return x

然后,您可以将模型中的 conv/bn/relu 层与

SkipConv
块之一交换。

© www.soinside.com 2019 - 2024. All rights reserved.