我正在研究修剪对 CoreML 模型的影响。
虽然我可以轻松测量文件大小的变化(以 kB 为单位),但我很难估计模型参数数量的变化,因为 CoreML 不提供像 PyTorch 那样的直接方法
model.parameters()
。如何估计或计算修剪后的 CoreML 模型中的参数数量?
除了迭代各层并根据层的类型计算参数之外,似乎没有直接且通用的方法来执行此操作。
例如,
存储带有卷积层和线性层的 CoreML 模型
import torch
import torch.nn as nn
import coremltools as ct
import os
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.fc1 = nn.Linear(16 * 32 * 32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x
model = SimpleNet()
dummy_input = torch.rand(1, 3, 32, 32)
traced_model = torch.jit.trace(model, dummy_input)
coreml_model = ct.convert(
traced_model,
inputs=[ct.TensorType(name="input", shape=dummy_input.shape)]
)
model_name = "simple_net"
coreml_model.save(f"{model_name}.mlpackage")
该模型有两层:卷积层和线性层。因此,我们只需使用
model.get_spec().neuralNetwork.layers
迭代各层,对于卷积层,kernel_channels * output_channels * kernel_height * kernel_width
给出权重,output_channels
是偏差项的数量。对于线性层来说,它只是输入数量乘以输出数量。
def count_parameters(model):
total_params = 0
for layer in model.get_spec().neuralNetwork.layers:
if layer.HasField('convolution'):
conv = layer.convolution
kernel_channels = conv.kernelChannels
output_channels = conv.outputChannels
kernel_height = conv.kernelSize[0]
kernel_width = conv.kernelSize[1]
params = kernel_channels * output_channels * kernel_height * kernel_width
if conv.hasBias:
params += output_channels
total_params += params
elif layer.HasField('innerProduct'):
inner = layer.innerProduct
input_channels = inner.inputChannels
output_channels = inner.outputChannels
params = input_channels * output_channels
if inner.hasBias:
params += output_channels
total_params += params
return total_params
coreml_params = count_parameters(coreml_model)
print(f"Total parameters: {coreml_params}")
打印出来了
Number of parameters in PyTorch model: 164298
如果直接在Pytorch模型上通过
来算total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
很匹配
Total parameters: 164298
不是最好的方法(我希望其他人有更通用的方法),但这是一个开始。