Below is a simplified snippet from my NTM controller (the actual module is more complex, involving convolutional fusion of 2D attention maps followed by MLP processing):
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

# BasicResBlock2D is a standard 2D residual block defined elsewhere in my code.

class NTMController(nn.Module):
    """
    An NTM controller that fuses 2D attention maps with the base representation
    and produces control signals (read/write keys, etc.) via an MLP.
    """
    def __init__(self, d_model=128, mem_dim=128, hidden_dim=256, n_layers=3,
                 fuse_in_channels=32, fuse_out_channels=32):
        super().__init__()
        self.d_model = d_model
        self.mem_dim = mem_dim
        # Layers for 2D attention fusion:
        self.conv_merge = nn.Conv2d(fuse_in_channels, fuse_out_channels, kernel_size=3, padding=1)
        self.resblock = BasicResBlock2D(fuse_out_channels, fuse_out_channels, stride=1)
        self.final_linear = nn.Linear(fuse_out_channels, fuse_out_channels)
        # MLP for generating control signals; its input is the base representation
        # concatenated with the fused 2D features.
        mlp_in_dim = d_model + fuse_out_channels
        layers = []
        in_dim = mlp_in_dim
        for _ in range(n_layers):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            in_dim = hidden_dim
        # The final layer produces 4 * mem_dim + 1 outputs
        # (read key, write key, erase, add, scale).
        layers.append(nn.Linear(hidden_dim, 4 * mem_dim + 1))
        self.mlp = nn.Sequential(*layers)

    def _fuse_2d_maps(self, attn_intra: Optional[torch.Tensor],
                      attn_hier: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if attn_intra is None and attn_hier is None:
            return None
        # If both maps are provided, interpolate them to the same spatial size
        # and concatenate along the channel dimension.
        if attn_intra is not None and attn_hier is not None:
            B, C_i, Hi, Wi = attn_intra.shape
            B2, C_h, Hh, Wh = attn_hier.shape
            Hmax, Wmax = max(Hi, Hh), max(Wi, Wh)
            if (Hi, Wi) != (Hmax, Wmax):
                attn_intra = F.interpolate(attn_intra, size=(Hmax, Wmax), mode='bilinear', align_corners=False)
            if (Hh, Wh) != (Hmax, Wmax):
                attn_hier = F.interpolate(attn_hier, size=(Hmax, Wmax), mode='bilinear', align_corners=False)
            x = torch.cat([attn_intra, attn_hier], dim=1)
        else:
            x = attn_intra if attn_intra is not None else attn_hier
        x = F.relu(self.conv_merge(x))
        x = self.resblock(x)
        B, C, H, W = x.shape
        # Global average pool, then project to a fixed-size feature vector.
        x_pool = F.adaptive_avg_pool2d(x, (1, 1)).view(B, C)
        return self.final_linear(x_pool)

    def forward(self, reps: torch.Tensor,
                attn_hier: Optional[torch.Tensor] = None,
                attn_intra: Optional[torch.Tensor] = None) -> tuple:
        # Fuse the attention maps into a single feature vector.
        fused_2d = self._fuse_2d_maps(attn_intra, attn_hier)
        if fused_2d is None:
            fused_2d = torch.zeros(reps.size(0), self.final_linear.out_features, device=reps.device)
        # Concatenate the base representation with the fused 2D features.
        cat_inp = torch.cat([reps, fused_2d], dim=-1)
        # clone() is used here to avoid potential in-place modifications:
        out = self.mlp(cat_inp.clone())  # <--- this is where the traceback points
        # Split the output into the individual control signals.
        read_key, write_key, erase_raw, add_vec, scale_raw = torch.split(
            out, [self.mem_dim, self.mem_dim, self.mem_dim, self.mem_dim, 1], dim=-1
        )
        scale = torch.sigmoid(scale_raw).squeeze(-1)
        # For now, duplicate the scale for read and write operations.
        return read_key, write_key, erase_raw, add_vec, scale, scale
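For context, the training step around this controller has roughly the following shape. This is a condensed, hypothetical sketch rather than my real loop: the dummy loss, the random inputs, and the memory-module calls (shown only as comments) are placeholders, and it assumes NTMController above with BasicResBlock2D available.

    # Hypothetical sketch of one training step; all names below are placeholders.
    controller = NTMController()
    optimizer = torch.optim.Adam(controller.parameters(), lr=1e-4)

    def train_step(reps, attn_intra):
        optimizer.zero_grad()
        read_key, write_key, erase_raw, add_vec, read_scale, write_scale = controller(
            reps, attn_intra=attn_intra
        )
        # In the real code these signals drive the memory module, e.g.:
        # read_vec = memory_module.read(read_key, read_scale)
        # memory_module.write(write_key, erase_raw, add_vec, write_scale)
        loss = read_key.pow(2).mean()  # stand-in for the real loss on the final output
        loss.backward()                # called exactly once per step
        optimizer.step()
        return loss.item()

    # e.g. train_step(torch.randn(2, 128), torch.randn(2, 32, 8, 8))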
I also have a memory module (not shown) that uses these control signals to perform the read/write operations. The error is raised when I call loss.backward() on the final output, and the traceback points specifically at self.mlp(cat_inp.clone()) inside NTMController.forward(). To be clear, I call loss.backward() exactly once, not twice. My questions are:
1. What might be causing the computation graph to be traversed twice?
2. Are there any common pitfalls in combining cloned tensors, detached tensors, or in-place operations that could lead to this error? (A toy example of the kind of pattern I mean is sketched right after these questions.)
3. How can I restructure my code so that each forward pass only builds a single graph for the backward pass, without resorting to retain_graph=True?
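To make question 2 concrete: the textbook pattern I understand can cause this message is carrying state that belongs to a previous iteration's graph into the next iteration's loss, e.g. an NTM memory that is never detached between training steps. A toy sketch of that pattern (toy_model and memory are placeholders, not my actual code):

    import torch
    import torch.nn as nn

    toy_model = nn.Linear(8, 8)
    opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)

    # Persistent state (think: NTM memory) that stays attached to the previous step's graph.
    memory = torch.zeros(1, 8)

    for step in range(3):
        out = toy_model(memory)
        memory = memory + out   # the new memory still references the old step's graph
        loss = out.sum()
        loss.backward()         # raises "backward through the graph a second time" on step 1
        opt.step()
        opt.zero_grad()
        # The usual fix: memory = memory.detach() at the end of each step.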
Any insight or suggestions on how to diagnose this further would be greatly appreciated!
Update: this error was resolved after I redesigned the sampling approach (see "PyTorch checkpoint error: recomputed tensor metadata mismatch in global representation with extra sampling"); it was related to the fact that I was not computing a separate global representation per forward pass. That is something I could never have figured out from the traceback alone, since it made no mention of the sampling part of the code.
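In case it helps anyone hitting the same message: the fix boiled down to building the shared global representation inside every forward pass rather than reusing one computed earlier, whose graph had already been freed by the first backward. A toy illustration of the difference; encoder, head, and corpus below are placeholders, not my actual modules:

    import torch
    import torch.nn as nn

    # Toy stand-ins for the real encoder / downstream model.
    encoder = nn.Linear(16, 16)
    head = nn.Linear(16, 1)
    opt = torch.optim.SGD(list(encoder.parameters()) + list(head.parameters()), lr=0.01)
    corpus = torch.randn(4, 16)

    # Broken pattern: the global representation is built once, so every step's loss
    # backpropagates into the same already-freed graph.
    # global_rep = encoder(corpus).mean(dim=0, keepdim=True)
    # for _ in range(3):
    #     loss = head(global_rep).sum()
    #     loss.backward()   # fails on the second iteration
    #     opt.step(); opt.zero_grad()

    # Fixed pattern: recompute the global representation inside each step so every
    # backward() only traverses a graph built during that same forward pass.
    for _ in range(3):
        global_rep = encoder(corpus).mean(dim=0, keepdim=True)
        loss = head(global_rep).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()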