我正在寻找一种方法来检索保存为 .pth 文件的 PyTorch 网络中使用的激活函数 (
torch.save(model)
)。事实上,如果在创建模型时未在类中声明激活函数,而是仅在前向方法中声明激活函数,我将无法识别这些激活函数
例如:
class SimpleCNN_2(nn.Module):
def __init__(self):
super(SimpleCNN_2, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(in_features=32768, out_features=128)
self.fc2 = nn.Linear(in_features=128, out_features=10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = nn.Flatten()(x)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
我得到这样的描述:
SimpleCNN_2(
(conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1): Linear(in_features=32768, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
如果声明了激活函数,我就没有任何问题。我担心的是,我想从 .pth 文件评估网络,如果我无法获取网络的整个结构,那就一团糟。
我希望有一种方法,可以让我从 .pth 文件中以列表的形式逐层获取网络的完整结构,而无需提前知道网络的结构。
我认为问题在于激活函数没有任何序列化的参数,所以你不会在
pth
文件中找到它。通常在 pytorch 中,您需要模型,您将需要模型定义来获取所有内容。一种解决方法可能是使用Sequential。我想这也应该存储激活函数。