据我所知,在 pytorch 中有两种表达 ResNet Block 的方法:
这导致了两种代码:
def forward(self, x):
y = x
x = self.conv1(x)
x = self.norm1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.norm2(x)
x += y
x = self.act2(x)
return x
def forward(self, x):
y = self.conv1(x)
y = self.norm1(y)
y = self.act1(y)
y = self.conv2(y)
y = self.norm2(y)
y += x
y = self.act2(y)
return y
它们相同吗?哪一个是首选?为什么?
只要您保留对输入的一些引用就没关系。
在较高的层面上,你正在尝试计算
output = activation(input + f(input))
所示的两种方法都可以实现这一点。只要您不丢失
input
引用或通过就地操作更改 input
,就应该没问题。
为了清楚起见,我会将剩余连接和子块分开:
class Block(nn.Module):
def __init__(self, ...):
super().__init__()
self.conv1 = ...
self.norm1 = ...
self.act = ...
self.conv2 = ...
self.norm2 = ...
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.act(x)
x = self.conv2(x)
x = self.norm2(x)
return x
class ResBlock(nn.Module):
def __init__(self, block):
super().__init__()
self.block = block
self.act = ...
def forward(self, x):
return self.act(x + self.block(x))