RuntimeError: Trying to backward through the graph a second time - how can I fix this without using retain_graph=True?

I noticed this error occurs during the backward pass through my NTM controller's MLP. I initially tried to work around it with retain_graph=True, but that quickly blows up my memory. I want to fix the underlying problem so the model computes gradients correctly without retaining the entire graph.

Below is a simplified snippet from my NTM controller (the actual module is more complex, involving convolutional fusion of 2D attention maps followed by MLP processing):

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# BasicResBlock2D is a small residual conv block defined elsewhere in my codebase.

class NTMController(nn.Module):
    """
    A fully advanced NTM controller that fuses 2D attention maps with the base representation.
    Produces control signals (read/write keys, etc.) via an MLP.
    """
    def __init__(self, d_model=128, mem_dim=128, hidden_dim=256, n_layers=3,
                 fuse_in_channels=32, fuse_out_channels=32):
        super().__init__()
        self.d_model = d_model
        self.mem_dim = mem_dim

        # Layers for 2D attention fusion:
        self.conv_merge = nn.Conv2d(fuse_in_channels, fuse_out_channels, kernel_size=3, padding=1)
        self.resblock = BasicResBlock2D(fuse_out_channels, fuse_out_channels, stride=1)
        self.final_linear = nn.Linear(fuse_out_channels, fuse_out_channels)

        # MLP for generating control signals; input is the base representation concatenated with the fused 2D features.
        mlp_in_dim = d_model + fuse_out_channels
        layers = []
        in_dim = mlp_in_dim
        for _ in range(n_layers):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            in_dim = hidden_dim
        # Final layer produces 4*mem_dim + 1 outputs.
        layers.append(nn.Linear(hidden_dim, 4 * mem_dim + 1))
        self.mlp = nn.Sequential(*layers)

    def _fuse_2d_maps(self, attn_intra: Optional[torch.Tensor], attn_hier: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if attn_intra is None and attn_hier is None:
            return None
        # If both maps are provided, interpolate to the same spatial size and concatenate along the channel dim.
        if attn_intra is not None and attn_hier is not None:
            B, C_i, Hi, Wi = attn_intra.shape
            B2, C_h, Hh, Wh = attn_hier.shape
            Hmax, Wmax = max(Hi, Hh), max(Wi, Wh)
            if (Hi, Wi) != (Hmax, Wmax):
                attn_intra = F.interpolate(attn_intra, size=(Hmax, Wmax), mode='bilinear', align_corners=False)
            if (Hh, Wh) != (Hmax, Wmax):
                attn_hier = F.interpolate(attn_hier, size=(Hmax, Wmax), mode='bilinear', align_corners=False)
            x = torch.cat([attn_intra, attn_hier], dim=1)
        else:
            x = attn_intra if attn_intra is not None else attn_hier

        x = F.relu(self.conv_merge(x))
        x = self.resblock(x)
        B, C, H, W = x.shape
        x_pool = F.adaptive_avg_pool2d(x, (1, 1)).view(B, C)
        return self.final_linear(x_pool)

    def forward(self, reps: torch.Tensor,
                attn_hier: Optional[torch.Tensor] = None,
                attn_intra: Optional[torch.Tensor] = None) -> tuple:
        # Fuse attention maps into a feature vector.
        fused_2d = self._fuse_2d_maps(attn_intra, attn_hier)
        if fused_2d is None:
            fused_2d = torch.zeros(reps.size(0), self.final_linear.out_features, device=reps.device)
        # Concatenate the base representation with the fused 2D features.
        cat_inp = torch.cat([reps, fused_2d], dim=-1)
        # Use clone() here to avoid potential in-place modifications:
        out = self.mlp(cat_inp.clone())    # <--- this is where the traceback points
        # Split the output into control signals.
        read_key, write_key, erase_raw, add_vec, scale_raw = torch.split(
            out, [self.mem_dim, self.mem_dim, self.mem_dim, self.mem_dim, 1], dim=-1
        )
        scale = torch.sigmoid(scale_raw).squeeze(-1)
        # For now, duplicate scale for read and write operations.
        return read_key, write_key, erase_raw, add_vec, scale, scale

I also have a memory module that performs read/write operations using these control signals. The error occurs when loss.backward() is called on the final output, and the traceback points specifically at self.mlp(cat_inp.clone()) inside NTMController.forward(). To clarify, I only call loss.backward() once, not twice. My questions are below (a minimal standalone sketch of the failure mode I suspect follows the list):

1. What might be causing the computation graph to be traversed twice?
2. Are there any common pitfalls in combining cloned tensors, detached tensors, or using in-place operations that could lead to this error?
3. How can I restructure my code so that each forward pass only builds a single graph for the backward pass, without resorting to retain_graph=True?
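
To make the failure mode concrete, here is a minimal, self-contained snippet (hypothetical names, unrelated to my actual modules) showing the most common pattern I'm aware of that produces this exact error: a tensor produced in one training iteration is reused in the next one without being detached, so the second loss.backward() has to walk a graph whose buffers were already freed:

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

memory = torch.zeros(1, 4)  # persistent state, e.g. NTM memory carried across batches

for step in range(3):
    x = torch.randn(1, 4)
    update = lin(x)
    loss = (memory + update).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Broken: keeps this step's graph alive inside `memory`, so the *next*
    # backward() raises "Trying to backward through the graph a second time".
    # memory = memory + update
    # Working: cut the history at the iteration boundary.
    memory = (memory + update).detach()

I'm not sure whether something equivalent is happening in my training loop, which is why I'm asking.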

Any insight or suggestions for further diagnosing this issue would be greatly appreciated!
    

This error was resolved after I reworked my sampling approach (see

Pytorch checkpoint error: metadata of recomputed tensors does not match, in the global representation with extra sampling

); it was related to the fact that I was not computing a separate global representation for each forward pass.
This is something I would never have figured out from the traceback I got, since it made no mention of the sampling part of the code.
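
Roughly, the shape of the fix looked like this (heavily simplified; encoder/controller are placeholder modules, not my real ones). The broken version computed the global representation once and reused it, so every loss.backward() after the first had to re-traverse the encoder's already-freed graph; rebuilding it inside each forward pass gives every iteration its own graph:

import torch
import torch.nn as nn

encoder = nn.Linear(8, 8)
controller = nn.Linear(8, 1)
params = list(encoder.parameters()) + list(controller.parameters())
opt = torch.optim.SGD(params, lr=0.1)
data = torch.randn(16, 8)

# Broken: one shared graph for all iterations.
# global_rep = encoder(data).mean(dim=0, keepdim=True)

for step in range(3):
    # Working: each iteration builds a fresh graph that backward() is free to release.
    global_rep = encoder(data).mean(dim=0, keepdim=True)
    loss = controller(global_rep).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()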
    

pytorch neural-network runtime-error backpropagation autograd