假设我有一批火炬张量
M
形式的图像(B, W, H)
,以及大小为I
的图像(W, H)
,其像素是索引。
我想获取一个图像
(W, H)
,其中每个像素都来自图像批次中的相应图像(遵循I
的索引)。
示例
给定
M
的形状 (3, 4, 8)
:
tensor([[[ 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0.]],
[[-1., -1., -1., -1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1., -1.]],
[[-2., -2., -2., -2., -2., -2., -2., -2.],
[-2., -2., -2., -2., -2., -2., -2., -2.],
[-2., -2., -2., -2., -2., -2., -2., -2.],
[-2., -2., -2., -2., -2., -2., -2., -2.]]])
形状 I
的和
(4, 8)
:
tensor([[2, 0, 2, 0, 1, 0, 1, 0],
[2, 2, 1, 0, 0, 2, 1, 0],
[2, 0, 0, 2, 1, 1, 0, 0],
[0, 1, 0, 0, 2, 0, 2, 1]], dtype=torch.int32)
生成的图像将是:
tensor([[-2., 0., -2., 0., -1., 0., -1., 0.],
[-2., -2., -1., 0., 0., -2., -1., 0.],
[-2., 0., 0., -2., -1., -1., 0., 0.],
[ 0., -1., 0., 0., -2., 0., -2., -1.]])
注1
我不关心
M
尺寸的顺序,如果它提供了更简单的解决方案,它也可以是 (W, H, B)
。
注2
我也对 NumPy 解决方案感兴趣。
一种解决方案是:
indices = torch.meshgrid(torch.arange(I.shape[0]), torch.arange(I.shape[1]))
result = M[I, *indices]
或使用numpy:
indices = np.indices(I)
result = M[I, indices[0], indices[1]]