在PyTorch中切片具有3D张量索引的4D张量

问题描述 投票:1回答:1

我有一个4D张量(恰好是一个56x56图像的三批堆栈,其中每批具有16张图像),大小为[16、3、56、56] 。我的目标是为每个像素选择这三批中的正确批处理(我的索引图的大小为[16,56,56])并获得所需的图像。

现在,我要选择这三批中的特定批图像,其值具有如

       [[[ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  2,  2,  ...,  0,  0,  0]],

        [[ 0,  2,  0,  ...,  1,  1,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  2,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  0,  0]]]

因此对于0,将从第一批中选择值,其中1和2表示我要从第二和第三批中选择值。

这里是一些索引的可视化效果,每种颜色表示另一批次。

enter image description here

enter image description here

我已经尝试过转置4D张量以匹配索引的尺寸,但是没有用。它所做的就是给我一份我尝试选择的尺寸的副本。方式

tposed = torch.transpose(fourD, 0,1) print(indices.size(),
outs.size(), tposed[:, indices].size())

输出

torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])

虽然我需要的形状是

torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])

例如,如果我尝试仅使用批处理中的第一个图像选择正确的值,则>]

fourD[0,indices].size()

我得到一个类似的形状

torch.Size([16, 56, 56, 56, 56])

更不用说当我在整个张量上尝试时出现内存不足错误。

我非常感谢使用这些索引为我的图像中的每个像素选择这三批中的任一批

注意:

我尝试过该选项

outs[indices[:,None,:,:]].size()

返回]

torch.Size([16, 1, 56, 56, 3, 56, 56])

编辑:torch.take并没有多大帮助,因为它将输入张量视为一维数组。

我有一个4D张量(恰好是三批56x56图像的堆栈,每批具有16张图像),其大小为[16、3、56、56]。我的目标是从这三个中选择正确的一个...

pytorch slice tensor
1个回答
1
投票

原来PyTorch中有一个功能具有我要搜索的功能。

torch.gather(fourD, 1, indices.unsqueeze(1)) 
© www.soinside.com 2019 - 2024. All rights reserved.