我对以下代码片段有一些疑问:
>>> def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
apply()是pytorch.nn包的一部分。您可以在此软件包的文档中找到该代码。最后的问题:1。为什么这个代码示例有效,尽管在给予apply()时init_weights()中没有添加任何参数或括号? 2.函数init_weights(m)从哪里得到它的参数m,当它作为参数给出函数apply()而没有括号和m?
我们在torch.nn.Module.apply(fn)
的文档中找到了您的问题的答案:
将
fn
递归地应用于每个子模块(由.children()返回)以及self。典型用途包括初始化模型的参数(另请参见torch-nn-init)。
init_weights
在apply
调用之前不被调用,正是因为没有括号,而是对init_weights
给apply
的引用,并且仅在apply
之后的init_weights
中被称为。apply
中的每个调用得到它的参数,并且,正如文档所述,由于方法调用net
,它被调用m迭代(在这种情况下)net
以及net.apply(…)
本身的每个子模块。