使用 PyTorch 张量从索引图像索引一批图像

问题描述 投票:0回答:1

假设我有一批火炬张量

M
形式的图像
(B, W, H)
,以及大小为
I
的图像
(W, H)
,其像素是索引。

我想获取一个图像

(W, H)
,其中每个像素都来自图像批次中的相应图像(遵循
I
的索引)。

示例

给定

M
的形状
(3, 4, 8)

tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],

        [[-1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1., -1., -1., -1.]],

        [[-2., -2., -2., -2., -2., -2., -2., -2.],
         [-2., -2., -2., -2., -2., -2., -2., -2.],
         [-2., -2., -2., -2., -2., -2., -2., -2.],
         [-2., -2., -2., -2., -2., -2., -2., -2.]]])
形状

I

(4, 8)
:

tensor([[2, 0, 2, 0, 1, 0, 1, 0],
        [2, 2, 1, 0, 0, 2, 1, 0],
        [2, 0, 0, 2, 1, 1, 0, 0],
        [0, 1, 0, 0, 2, 0, 2, 1]], dtype=torch.int32)

生成的图像将是:

tensor([[-2.,  0., -2.,  0., -1.,  0., -1.,  0.],
        [-2., -2., -1.,  0.,  0., -2., -1.,  0.],
        [-2.,  0.,  0., -2., -1., -1.,  0.,  0.],
        [ 0., -1.,  0.,  0., -2.,  0., -2., -1.]])

注1

我不关心

M
尺寸的顺序,如果它提供了更简单的解决方案,它也可以是
(W, H, B)

注2

我也对 NumPy 解决方案感兴趣。

python numpy indexing torch
1个回答
0
投票

一种解决方案是:

indices = torch.meshgrid(torch.arange(I.shape[0]), torch.arange(I.shape[1]))
result = M[I, *indices]

或使用numpy:

indices = np.indices(I)
result = M[I, indices[0], indices[1]]
© www.soinside.com 2019 - 2024. All rights reserved.