Here is the `forward()` method of my PyTorch model:
```python
# torch and numpy (imported as np) are module-level imports
def forward(self, x, output_type, *unused_args, **unused_kwargs):
    gru_output, gru_hn = self.gru(x)
    # Decoder (graph adjacency reconstruction)
    for data_batch_idx in range(x.shape[0]):
        pred = self.decoder(gru_output[data_batch_idx, -1, :])  # gru_output[-1] => only take the last time-step
        pred_graph_adj = pred.reshape(1, -1) if data_batch_idx == 0 else torch.cat((pred_graph_adj, pred.reshape(1, -1)), dim=0)
    if output_type == "discretize":
        bins = torch.tensor(self.model_cfg['output_bins']).reshape(-1, 1)
        num_bins = len(bins) - 1
        bins = torch.concat((bins[:-1], bins[1:]), dim=1)
        discretize_values = np.linspace(-1, 1, num_bins)
        for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
            pred_graph_adj = torch.where((pred_graph_adj <= upper) & (pred_graph_adj > lower), discretize_value, pred_graph_adj)
        pred_graph_adj = torch.where(pred_graph_adj < bins.min(), bins.min(), pred_graph_adj)
    return pred_graph_adj
```
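To make the `"discretize"` branch easier to follow, here is a small standalone sketch of what the bin set-up produces, using a hypothetical `output_bins` value (my real config differs):

```python
import numpy as np
import torch

output_bins = [-1.0, -0.5, 0.5, 1.0]               # hypothetical config value
bins = torch.tensor(output_bins).reshape(-1, 1)    # shape (4, 1): the bin edges
num_bins = len(bins) - 1                           # 3 bins
bins = torch.concat((bins[:-1], bins[1:]), dim=1)  # (lower, upper) pairs per row:
                                                   # [[-1.0, -0.5], [-0.5, 0.5], [0.5, 1.0]]
discretize_values = np.linspace(-1, 1, num_bins)   # one target value per bin: [-1., 0., 1.]
```

Each prediction that falls inside a bin is then replaced by that bin's target value via `torch.where`.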
Here is the training snippet:
```python
pred = self.forward(x, output_type=self.model_cfg['output_type'])
batch_loss = self.loss_fn(pred, y)
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
self.scheduler.step()
```
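For completeness, the gradient sums mentioned below are read after `batch_loss.backward()` has run (and before the next `zero_grad()`), since `.grad` is only populated by the backward pass; the placement shown here is illustrative, not my exact loop:

```python
# Illustrative placement of the gradient check, not the exact training loop.
batch_loss.backward()
self.optimizer.step()
grad_sum = sum([p.grad.sum() for p in self.decoder.parameters()])
print(grad_sum)  # zero when output_type == "discretize", non-zero otherwise
```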
When `output_type` is not `"discretize"` (i.e. `torch.where` is not used), `sum([p.grad.sum() for p in self.decoder.parameters()])` is non-zero.

When `output_type` is `"discretize"` (i.e. `torch.where` is used), `sum([p.grad.sum() for p in self.decoder.parameters()])` is zero, even though `batch_loss` is not zero, `requires_grad` is `True` for both `pred` and `batch_loss`, and both are connected to the model's weights.

My questions: does `torch.where` cause the model's parameter gradients to become zero? And if `torch.where` is not the cause, what other explanations are possible?
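For reference, here is a minimal, self-contained sketch (hypothetical bins and a single scalar weight, not my actual model) that reproduces the combination I observe: non-zero loss, `requires_grad=True`, and zero gradients.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
pred = w * torch.tensor([0.3, -0.5, 0.8])    # pred depends on w, like the decoder output

# Mimic the "discretize" branch: every element ends up replaced by a constant.
bins = [(-2.0, 0.0, -1.0), (0.0, 2.0, 1.0)]  # hypothetical (lower, upper, discretize_value)
out = pred
for lower, upper, value in bins:
    out = torch.where((out <= upper) & (out > lower), torch.tensor(value), out)

loss = out.sum()
loss.backward()
print(loss)               # tensor(1., grad_fn=<SumBackward0>) -> non-zero, like batch_loss
print(out.requires_grad)  # True
print(w.grad)             # tensor([0.]) -> zero, like my decoder gradients
```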