很抱歉问了这个非常基本的问题(我是 Keras 的新手)。我想知道 Keras 如何在早期阶段(拟合之前)计算每一层的参数数量,尽管 model.summary 显示在这个阶段有些维度仍然没有值。这些值是否已以某种方式确定?如果是,为什么不在摘要中显示它们?
我问这个问题是因为我很难弄清楚我的“张量形状错误”(我试图确定我的 resnet50 模型的 C5 块的输出尺寸,但我什至无法在 model.summary 中看到它们如果我看到参数的数量)。
我在下面给出了一个基于 RetinaNet 中的 C5_reduced 层的示例,该层由 Resnet50 的 C5 层提供。 C5_reduced 是
Conv2D(256,kernel_size=1,strides=1,pad=1)
基于该特定层的 model.summary:
C5_reduced (Conv2D) (None, None, None, 256) 524544
我猜测 C5 是 (None,1,1,2048),因为 2048*256+256 = 524544 (我不知道如何证实或否定该假设)。因此,如果已经知道,为什么不在摘要中显示它呢?如果维度 2 和维度 3 不同,参数的数量也会不同,对吗?
如果将精确的输入形状传递到网络上的第一层或输入层,您将获得所需的输出。例如,我在这里使用输入层:
input_1 (InputLayer) [(None, 224, 224, 3)] 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
将输入传递为 (224,224,3)。这里的3代表深度。请注意,卷积参数的计算与密集层的计算不同。
如果您执行以下操作:
tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(150, 150, 3))
你会看到:
conv2d (Conv2D) ---> (None, 148, 148, 16)
尺寸减小到 148x148,在 Keras 中默认填充为
valid
。另外strides
为1。那么输出的形状将是148 x 148。(你可以搜索公式。)
那么None值是什么?
编辑:
tf.keras.layers.Input(shape = (None, None, 3)),
tf.keras.layers.Conv2D(16, (3,3), activation='relu')
产品:
conv2d_21 (Conv2D) (None, None, None, 16) 448
关于你的问题,即使我们将图像高度和宽度传递为None,参数是如何计算的?
卷积参数计算依据:
(filter_height * filter_width * input_image_channels + 1) * number_of_filters
当我们将它们代入公式时,
filter_height = 3
filter_width = 3
input_image_channel = 3
number_of_filters = 16
参数 = (3 x 3 x 3 + 1) * 16 = 28 * 16 = 448
注意,我们只需要 input_image 的通道号为 3,代表它是 RGB 图像。
如果要计算后面卷积的参数,则需要考虑前一层的滤波器数量将成为当前层通道的通道数。
这就是为什么你最终会得到 None 参数而不是batch_size。在这种情况下,Keras 需要知道您的图像是否为 RGB。或者,您在创建模型时不会指定维度,并且可以在将模型与数据集拟合时传递它们。
您需要为您的模型定义一个输入层。 可训练参数的总数是未知的,直到您 a) 编译模型并向其提供数据,此时模型根据输入的维度制作图表,然后您将能够确定参数的数量,或者b) 您为模型定义一个输入层,并指定输入维度,然后您可以使用 model.summary() 找到参数数量。
关键在于,模型无法知道输入和第一个隐藏层之间的参数数量,直到定义它,或者运行推理并给出输入的形状。
你能解释一下下面公式中的+1吗?卷积滤波器尺寸高度*宽度*通道,我们只是将滤波器与输入图像进行点积(每次滑动)。计算没有偏差。
(filter_height * filter_width * input_image_channels + 1) * number_of_filters