Using torch to gather different pixels from each image in an image stack


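I have a batch of images and, for each image, a batch of (x, y) indices. The indices are different for every image, so I can't use simple indexing. What is the best or fastest way to get another batch that holds the colors of the selected pixels from each image?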

    import torch

    n_images = 4
    width = 100
    height = 100
    channels = 3
    n_samples = 30

    images = torch.rand((n_images, height, width, channels))
    # (x, y) pairs; scaling both by width only works because width == height here
    indices = (torch.rand((n_images, n_samples, 2)) * width).to(torch.int32)

    # what I would like to write:
    # result = images[indices]
    # with result.shape == (n_images, n_samples, 3)

    # I found this workaround, but I would rather call a general torch function
    xs = indices.reshape((-1, 2))[:, 0]
    ys = indices.reshape((-1, 2))[:, 1]
    ix = torch.arange(n_images, dtype=torch.int32)
    ix = ix[..., None].expand((-1, n_samples)).flatten()

    result = images[ix, ys, xs].reshape((n_images, n_samples, channels))
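For reference, the lookup asked for here is `result[i, j] = images[i, y, x]` with `x = indices[i, j, 0]` and `y = indices[i, j, 1]` (the same convention the workaround uses). A plain Python loop spells that out; it is slow and only meant as a ground truth for checking vectorized versions. The helper name `gather_pixels_loop` is hypothetical, and the sketch casts the indices to `long` to stay on the safe side of PyTorch's index-dtype rules.

```python
import torch

def gather_pixels_loop(images: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference (slow) per-image pixel lookup.

    images:  (n_images, height, width, channels)
    indices: (n_images, n_samples, 2), last dim is (x, y)
    returns: (n_images, n_samples, channels)
    """
    n_images, n_samples, _ = indices.shape
    channels = images.shape[-1]
    out = images.new_empty((n_images, n_samples, channels))
    for i in range(n_images):
        for j in range(n_samples):
            x = int(indices[i, j, 0])
            y = int(indices[i, j, 1])
            out[i, j] = images[i, y, x]  # row is y, column is x
    return out

# Same setup as the question (width == height, so scaling by width is safe here)
n_images, width, height, channels, n_samples = 4, 100, 100, 3, 30
images = torch.rand((n_images, height, width, channels))
indices = (torch.rand((n_images, n_samples, 2)) * width).long()

# The workaround from the question, for comparison
xs = indices.reshape((-1, 2))[:, 0]
ys = indices.reshape((-1, 2))[:, 1]
ix = torch.arange(n_images)[..., None].expand((-1, n_samples)).flatten()
workaround = images[ix, ys, xs].reshape((n_images, n_samples, channels))

reference = gather_pixels_loop(images, indices)
print(torch.equal(reference, workaround))  # True
```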
1 Answer

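To use a torch function, flatten the spatial dimensions of the images (height and width merged into one axis) and turn each (x, y) index into the matching flat index `y * width + x`. Then you can use `torch.gather` to select the pixels you need: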

    import torch

    n_images = 4
    width = 100
    height = 100
    channels = 3
    n_samples = 30

    images = torch.rand((n_images, height, width, channels))
    # cast to long instead of int32: torch.gather requires int64 indices
    indices = (torch.rand((n_images, n_samples, 2)) * width).long()

    # merge height and width into one axis: (n_images, height * width, channels)
    flatten_images = images.view(n_images, -1, channels)
    # flat index of pixel (x, y) is y * width + x; repeat it once per channel
    flatten_indices = (indices[..., 1:2] * width + indices[..., 0:1]).repeat(1, 1, channels)
    output = torch.gather(flatten_images, 1, flatten_indices).view(n_images, n_samples, channels)
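If you would rather avoid flattening, ordinary advanced indexing also picks a different pixel per image once the batch index is broadcast against the sample indices: `torch.arange(n_images)[:, None]` has shape `(n_images, 1)` and broadcasts with the `(n_images, n_samples)` y and x tensors, so `images[batch, ys, xs]` has shape `(n_images, n_samples, channels)`. The sketch below (the helper name `gather_pixels` is hypothetical) checks that this matches the `torch.gather` approach. Two choices are this sketch's own, not the answer's: it uses `expand` instead of `repeat` for the gather index, which avoids materializing a copy, and it draws x and y with separate ranges so non-square images are handled.

```python
import torch

def gather_pixels(images: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick one pixel per (x, y) pair, separately per image.

    images:  (n_images, height, width, channels)
    indices: (n_images, n_samples, 2) integer tensor, last dim is (x, y)
    returns: (n_images, n_samples, channels)
    """
    n_images = images.shape[0]
    batch = torch.arange(n_images, device=images.device)[:, None]  # (n_images, 1)
    xs = indices[..., 0].long()  # (n_images, n_samples)
    ys = indices[..., 1].long()  # (n_images, n_samples)
    # batch broadcasts against xs/ys, so each row of samples reads its own image
    return images[batch, ys, xs]

n_images, width, height, channels, n_samples = 4, 100, 80, 3, 30
images = torch.rand((n_images, height, width, channels))
# x in [0, width) and y in [0, height), drawn separately for non-square images
xs = torch.randint(0, width, (n_images, n_samples))
ys = torch.randint(0, height, (n_images, n_samples))
indices = torch.stack([xs, ys], dim=-1)  # (n_images, n_samples, 2)

# torch.gather version, with expand() in place of repeat()
flat_images = images.reshape(n_images, -1, channels)  # (n_images, height*width, channels)
flat_index = indices[..., 1] * width + indices[..., 0]  # row-major: y * width + x
flat_index = flat_index[..., None].expand(-1, -1, channels)  # (n_images, n_samples, channels)
via_gather = torch.gather(flat_images, 1, flat_index)

via_indexing = gather_pixels(images, indices)
print(via_indexing.shape)  # torch.Size([4, 30, 3])
print(torch.equal(via_indexing, via_gather))  # True
```

Both versions copy the same values, so they agree exactly. The indexing form reads closer to the `images[indices]` the question hoped for, while the `gather` form makes the flat-index arithmetic explicit.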
© www.soinside.com 2019 - 2024. All rights reserved.