我正在使用 PyTorch Lightning,我定义了我的模型,如下所示:
class MyModel(MyBaseClass):
def __init__(self, ..., **kwargs):
super().__init__(**kwargs)
self.model_parameter = nn.Parameter(
torch.rand(...)
)
我使用如下自定义损失函数:
class MyCustomLoss(pl.LightningModule):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, outputs, targets):
loss = ...
scalar_loss = torch.mean(loss)
return scalar_loss
在我的配置文件中,我设置了 class_path,如下所示:
model:
class_path: ...path_to_MyModel
init_args:
criterion:
class_path: ...path_to_MyCustomLoss
但是,我需要一种方法来访问自定义损失函数中的
model_parameter
。我需要这些参数来计算我的损失。如何在自定义损失函数中包含模型参数?
您需要将模型实例传递给您的自定义损失函数。一旦您的自定义损失函数可以访问模型实例,您就可以拉出
model_parameters
。其工作原理如下:
class MyCustomLoss(pl.LightningModule):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, outputs, targets, model):
# access model parameters
model_parameter = model.model_parameter
# proceed to calculate loss...
当您在训练步骤中调用损失函数时,传递模型实例:
class MyModel(MyBaseClass):
def __init__(self, ..., **kwargs):
super().__init__(**kwargs)
self.model_parameter = nn.Parameter(torch.rand(...))
self.criterion = MyCustomLoss(...) # Your custom loss function
def forward(self, x):
# define the forward pass
...
def training_step(self, batch, batch_idx):
inputs, targets = batch
outputs = self(inputs)
# compute loss
loss = self.criterion(outputs, targets, self)
return loss