为什么 pytorch 将网络类视为函数?

问题描述 投票:0回答:1

在查看 PyTorch 教程时,他们使用像函数一样定义的类。

例如

#Making an instance of the class NeuralNetwork
model = NeuralNetwork()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

我不知道这可以在 python 中完成,因此我对发生的事情以及被调用的内容感到困惑。我知道

forward()
函数(从 nn.Module 继承的任何类都必需的函数)是通过神经网络移动数据,但我不确定何时调用该函数以及
model(x)
是如何工作的。另外,在试图找到答案时,我遇到了这个link,其中一位贡献者说“pytorch模型是一个函数。你为它提供适当定义的输入,它返回一个输出。如果你只是想直观地检查给定特定输入图像的输出,只需将其称为:”。这让我更加困惑,因为模型是一个类,对吗?

python deep-learning pytorch
1个回答
0
投票

即使模型是一个类,它也可以实现为像函数一样可调用:

class MyCustomClass:
    def __call__(self, a):
        print(f'Printing {a}')

c = MyCustomClass()
c(42)    # Outputs `Printing 42`
© www.soinside.com 2019 - 2024. All rights reserved.