如何在我的自定义损失函数中包含模型的参数

问题描述 投票:0回答:1

我正在使用 PyTorch Lightning,我定义了我的模型,如下所示:

class MyModel(MyBaseClass):

    def __init__(self, ..., **kwargs):
        super().__init__(**kwargs)

        self.model_parameter = nn.Parameter(
            torch.rand(...) 
        )

我使用如下自定义损失函数:

class MyCustomLoss(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, outputs, targets):
        loss = ...
        scalar_loss = torch.mean(loss)
        return scalar_loss

在我的配置文件中,我设置了 class_path,如下所示:

model:
    class_path: ...path_to_MyModel
    init_args:
    criterion:
        class_path: ...path_to_MyCustomLoss

但是,我需要一种方法来访问自定义损失函数中的

model_parameter
。我需要这些参数来计算我的损失。如何在自定义损失函数中包含模型参数?

deep-learning pytorch loss-function gradient-descent pytorch-lightning
1个回答
0
投票

您需要将模型实例传递给您的自定义损失函数。一旦您的自定义损失函数可以访问模型实例,您就可以拉出

model_parameters
。其工作原理如下:

class MyCustomLoss(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, outputs, targets, model):
        # access model parameters
        model_parameter = model.model_parameter
        # proceed to calculate loss...
    

当您在训练步骤中调用损失函数时,传递模型实例:

class MyModel(MyBaseClass):

    def __init__(self, ..., **kwargs):
        super().__init__(**kwargs)
        self.model_parameter = nn.Parameter(torch.rand(...))
        self.criterion = MyCustomLoss(...)  # Your custom loss function

    def forward(self, x):
        # define the forward pass
        ...

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        # compute loss
        loss = self.criterion(outputs, targets, self)
        return loss
© www.soinside.com 2019 - 2024. All rights reserved.