这里有一些Python代码来重现我的问题:
import torch
n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)
print(x.shape)
print(x)
# torch.Size([9, 4])
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23],
# [24, 25, 26, 27],
# [28, 29, 30, 31],
# [32, 33, 34, 35]])
list_of_indices = [
[],
[2, 3],
[1],
[],
[],
[],
[0, 1, 2, 3],
[],
[0, 3],
]
print(list_of_indices)
for i, indices in enumerate(list_of_indices):
x[i, indices] = -1
print(x)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, -1, -1],
# [ 8, -1, 10, 11],
# [12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23],
# [-1, -1, -1, -1],
# [28, 29, 30, 31],
# [-1, 33, 34, -1]])
我有一个索引列表。我想使用
x
中的索引将 -1
中的索引设置为特定值(此处为 list_of_indices
)。在此列表中,每个子列表对应一行 x
,包含要设置为该行 -1
的索引。这可以使用 for 循环轻松完成,但我觉得 pytorch 可以更有效地做到这一点。
我尝试了以下方法:
x[torch.arange(len(list_of_indices)), list_of_indices] = -1
但结果是
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [9, 0]
我试图找到有同样问题的人,但是有关索引张量的问题数量如此之多,以至于我可能错过了。
这是因为
list_of_indices
是一个参差不齐的 list
(即它包含空嵌套 []
),所以如果我们包含一个返回 tensor
的函数,则与 shape
相同的 x
,其中 1
s 是来自 indices
的 list_of_indices
(0
s 是不在 list_of_indices
中的索引),那么我们可以将其输入到 torch.where
索引中 x
:
def get_indices_from_list(list_of_indices):
def fill_list(f):
_f = torch.zeros(4).long(); _f[f] = 1
return _f
return torch.stack([fill_list(i) for i in list_of_indices])
x[torch.where(get_indices_from_list(list_of_indices) == 1)] = -1
print(x)
输出:
tensor([[ 0, 1, 2, 3],
[ 4, 5, -1, -1],
[ 8, -1, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23],
[-1, -1, -1, -1],
[28, 29, 30, 31],
[-1, 33, 34, -1]])