从.pth模型获取Pytorch激活函数

问题描述 投票:0回答:1

我正在寻找一种方法来检索保存为 .pth 文件的 PyTorch 网络中使用的激活函数 (

torch.save(model)
)。事实上,如果在创建模型时未在类中声明激活函数,而是仅在前向方法中声明激活函数,我将无法识别这些激活函数 例如:


class SimpleCNN_2(nn.Module):
    def __init__(self):
        super(SimpleCNN_2, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=32768, out_features=128)  
        self.fc2 = nn.Linear(in_features=128, out_features=10)  

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = nn.Flatten()(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

我得到这样的描述:

SimpleCNN_2(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=32768, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

如果声明了激活函数,我就没有任何问题。我担心的是,我想从 .pth 文件评估网络,如果我无法获取网络的整个结构,那就一团糟。

我希望有一种方法,可以让我从 .pth 文件中以列表的形式逐层获取网络的完整结构,而无需提前知道网络的结构。

pytorch neural-network activation-function
1个回答
0
投票

我认为问题在于激活函数没有任何序列化的参数,所以你不会在

pth
文件中找到它。通常在 pytorch 中,您需要模型,您将需要模型定义来获取所有内容。一种解决方法可能是使用Sequential。我想这也应该存储激活函数。

© www.soinside.com 2019 - 2024. All rights reserved.