使用torch.where是否会导致模型参数梯度变为零?

问题描述 投票:0回答:0

这是我的pytorch模型的

forward()
方法:

    def forward(self, x, output_type, *unused_args, **unused_kwargs):
        gru_output, gru_hn = self.gru(x)
        # Decoder (Graph Adjacency Reconstruction)
        for data_batch_idx in range(x.shape[0]):
            pred = self.decoder(gru_output[data_batch_idx, -1, :])  # gru_output[-1] => only take last time-step
            pred_graph_adj = pred.reshape(1, -1) if data_batch_idx == 0 else torch.cat((pred_graph_adj, pred.reshape(1, -1)), dim=0)
        if output_type == "discretize":
            bins = torch.tensor(self.model_cfg['output_bins']).reshape(-1, 1)
            num_bins = len(bins)-1
            bins = torch.concat((bins[:-1], bins[1:]), dim=1)
            discretize_values = np.linspace(-1, 1, num_bins)
            for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
                pred_graph_adj = torch.where((pred_graph_adj <= upper) & (pred_graph_adj > lower), discretize_value, pred_graph_adj)
            pred_graph_adj = torch.where(pred_graph_adj < bins.min(), bins.min(), pred_graph_adj)

        return pred_graph_adj

这是训练的片段:

                pred = self.forward(x, output_type=self.model_cfg['output_type'])
                batch_loss = self.loss_fn(pred, y)
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
                self.scheduler.step()
  1. output_type
    不是“离散化”时
    (not using
    torch.where
    ), 
    sum([p.grad.sum() for p in self.decoder.parameters()])`将为非零。
    • 但是当
      output_type
      为“离散化”时
      (using
      torch.where
      ), 
      sum([p.grad.sum() for p in self.decoder.parameters()])` 将为零。
  2. 我检查过
    batch_loss
    ,它不是零。
  3. 我已经检查了模型重量的所有
    require_grad
    ,它们都是真实的。
  4. 我检查了计算图,
    pred
    batch_loss
    与模型的权重有关。

我的问题是:

  1. 使用
    torch.where
    是否会导致模型的参数梯度变为零?
  2. 如果
    torch.where
    不会造成这种情况,还有其他可能的原因吗?
python machine-learning deep-learning pytorch
© www.soinside.com 2019 - 2024. All rights reserved.