将布尔掩码与 PyTorch 张量中的切片相结合

问题描述 投票:0回答:1

我想用布尔掩码和普通索引来索引 pytorch 张量。像这样的东西:

i = 2
j = 0
mask = torch.randn(480, 360, 3) > 0
tensor = torch.zeros(480, 360, 4, 80)
tensor[mask[..., 0], i, j] = 1

numpy 等效项可以工作,在 pytorch 中它会抛出错误:

IndexError: The shape of the mask [480, 360] at index 1 does not match the shape of the indexed tensor [480, 80] at index 1

有什么想法或提示吗?

pytorch slice tensor
1个回答
0
投票

我遇到了类似的问题,并且我发现使用实际的稀疏索引而不是掩码可能会有所帮助。 具体来说,使用

.nonzero()
定位
True
元素并使用
True
索引进行索引。 请参阅我的问答作为示例:

x = torch.zeros(2, 3, 4, 6)
mask = torch.tensor([[ True, True, False], [True, False, True]])
y = torch.rand(2, 3, 1, 3)
i, j = mask.nonzero(as_tuple = True)
x[i, j, :, :3] = y[i, j]
© www.soinside.com 2019 - 2024. All rights reserved.