这是我的 CNN 网络,由“print(model.policy)”打印:
CnnPolicy(
(actor): Actor(
(features_extractor): CustomCNN(
(cnn): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(3): ReLU()
(4): Flatten(start_dim=1, end_dim=-1)
)
(linear): Sequential(
(0): Linear(in_features=6, out_features=128, bias=True)
(1): ReLU()
)
)
(mu): Sequential(
(0): Linear(in_features=128, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=128, bias=True)
(3): ReLU()
(4): Linear(in_features=128, out_features=3, bias=True)
(5): Tanh()
)
)
当我尝试使用 torchsummary.summary(model=model.policy, input_size=(1, 32, 32)) 打印网络架构时。我收到以下错误: 运行时错误:mat1 和 mat2 形状无法相乘(2x50176 和 6x128)
我尝试了很多“input_size”组合,但都是错误的。
我想知道如何选择'input-size'参数?
这不是总结的问题,而是你网络的问题。我认为由于
Flatten()
层及其第二个参数,您对层数感到困惑。
我建议您逐层组装网络,并通过输入随机
x = torch.from_numpy(np.random.rand(batch_dim, channel_dim, spatial1, spatial2)
来测试它,看看它是否可以很好地协同工作。
Flatten 通常用于展平通道和空间维度,但不展平批量维度。您将通道和一个空间维度展平,这可能不是您想要的。
此外,检查您的输入通道是否适合之前的输出通道。如果您提供一个复制粘贴的示例,而不仅仅是结构,我可以调试您的网络。
祝你好运!