我想用布尔掩码和普通索引来索引 pytorch 张量。像这样的东西:
i = 2
j = 0
mask = torch.randn(480, 360, 3) > 0
tensor = torch.zeros(480, 360, 4, 80)
tensor[mask[..., 0], i, j] = 1
numpy 等效项可以工作,在 pytorch 中它会抛出错误:
IndexError: The shape of the mask [480, 360] at index 1 does not match the shape of the indexed tensor [480, 80] at index 1
有什么想法或提示吗?
我遇到了类似的问题,并且我发现使用实际的稀疏索引而不是掩码可能会有所帮助。 具体来说,使用
.nonzero()
定位 True
元素并使用 True
索引进行索引。
请参阅我的问答作为示例:
x = torch.zeros(2, 3, 4, 6)
mask = torch.tensor([[ True, True, False], [True, False, True]])
y = torch.rand(2, 3, 1, 3)
i, j = mask.nonzero(as_tuple = True)
x[i, j, :, :3] = y[i, j]