检查Flux.jl模型中的参数总数?

问题描述 投票:0回答:2

我知道,使用flux.jl,我可以做

julia> Flux.params(model)
来获取参数,但是输出并未告诉我模型本身中实际存在多少个总参数。是否可以检查此功能或一种计算此方法的方法?

julia flux.jl
2个回答
2
投票

是@mcabbott在评论中提到的,您可以将整个模型传递给参数函数以获取总数(

sum(length, params(model))
)或循环通过每一层,如下所示:

julia> model = Chain(
         resnet[1:end-2],
         Dense(2048, 1000),  
         Dense(1000, 256),
         Dense(256, 2),        # we get 2048 features out, and we have 2 classes
       )
Chain(Chain(Conv((7, 7), 3=>64), MaxPool((3, 3), pad=1, stride=2), Metalhead.ResidualBlock((Conv((1, 1), 64=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), Chain(Conv((1, 1), 64=>256), BatchNorm(256))), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), Chain(Conv((1, 1), 256=>512), BatchNorm(512))), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), Chain(Conv((1, 1), 512=>1024), BatchNorm(1024))), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), Chain(Conv((1, 1), 1024=>2048), BatchNorm(2048))), Metalhead.ResidualBlock((Conv((1, 1), 2048=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), identity), Metalhead.ResidualBlock((Conv((1, 1), 2048=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), identity), MeanPool((7, 7)), #103), Dense(2048, 1000), Dense(1000, 256), Dense(256, 2))

julia> paramCount = 0
0

julia> for layer in model
           paramCount += sum(length, params(layer))
       end

julia> paramCount
25840234

在此示例中,我只是在增加计数,但是您可以将计数从每一层附加到一个数组中,以便单独跟踪每个层的计数。


0
投票
用于磁通0.16,使用

sum(length,Flux.trainables(model))

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.