为什么我的修剪模型的文件大小比初始模型大?

问题描述 投票:0回答:1

我正在使用this示例探索修剪神经网络。我的修剪代码使用预先训练的模型,如下所示:

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 64
epochs = 3
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = 114 * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity = 0.50,
                                                               final_sparsity = 0.80,
                                                               begin_step = 0,
                                                               end_step = end_step)
}

pruned_model = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
pruned_model.compile(optimizer = 'adam', loss = keras.losses.SparseCategoricalCrossentropy(from_logits = True), metrics = ['accuracy'])

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir = logdir),
]

pruned_model.fit(train_dataset, batch_size=batch_size, epochs=epochs, validation_data=valid_dataset, callbacks=callbacks)

然后,使用

pruned_model.evaluate(train_dataset, verbose=0)
,我发现它的准确度确实有所下降,正如预期的那样 - 我上次运行测试时得到了这些结果:

Baseline test accuracy: 0.9197102189064026
Pruned model test accuracy: 0.8976686000823975

我一直在使用

model.save()
保存初始
model
并将
pruned_model
修剪为 .h5 和 .keras 格式;这些分别达到 60.7 到 60.9 MB。然而,修剪后的网络在保存为 .h5 时大小为 85.4 MB,而 .keras 版本更大,为 110 MB。我在 keras 文档中找不到任何有关保存文件时需要指定优化的内容 - 只有目录。

python tensorflow keras neural-network pruning
1个回答
0
投票

由于序列化开销(修剪模型包括用于重建模型架构的额外元数据)、稀疏矩阵存储(某些存储格式不能有效压缩稀疏矩阵,并且存储它们的效率可能低于密集矩阵),修剪后的模型最终可能会变得更大。矩阵),检查点信息(包括剪枝状态)。

要解决此问题,您可以使用高效的存储格式,例如

.tflite
带量化:

converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('quantized_model.tflite', 'wb') as f:
    f.write(tflite_model)

和/或条状修剪专用包装纸,例如使用

strip_pruning
(请参阅文档)。

© www.soinside.com 2019 - 2024. All rights reserved.