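I have a batch of images and, for each image, a batch of (x, y) indices. The indices are different for every image, so I cannot use simple indexing. What is the best or fastest way to get another batch containing the colors of the selected pixels of each image?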
import torch

n_images = 4
width = 100
height = 100
channels = 3
n_samples = 30
images = torch.rand((n_images, height, width, channels))
indices = (torch.rand((n_images, n_samples, 2)) * width).to(torch.int32)
# desired: a single call such as
# result = images[indices]
# with result.shape == (n_images, n_samples, 3)
# I found the solution below, but I would rather call a general torch function
xs = indices.reshape((-1, 2))[:, 0]
ys = indices.reshape((-1, 2))[:, 1]
ix = torch.arange(n_images, dtype=torch.int32)
ix = ix[..., None].expand((-1, n_samples)).flatten()
result = images[ix, ys, xs].reshape((n_images, n_samples, 3))
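A more compact variant of the same indexing idea, which skips the manual flatten/reshape, relies on PyTorch advanced-indexing broadcasting: a batch index of shape (n_images, 1) broadcasts against per-image index tensors of shape (n_images, n_samples). This is a sketch, not taken from the question or the answer; it assumes the last dimension of `indices` stores (x, y) as in the question, and it samples x and y with separate bounds so non-square images also work.

```python
import torch

n_images, height, width, channels, n_samples = 4, 100, 100, 3, 30
images = torch.rand((n_images, height, width, channels))

# Sample x in [0, width) and y in [0, height) separately (int64 by default)
xs = torch.randint(0, width, (n_images, n_samples))
ys = torch.randint(0, height, (n_images, n_samples))
indices = torch.stack((xs, ys), dim=-1)  # (n_images, n_samples, 2), last dim = (x, y)

# Batch index of shape (n_images, 1) broadcasts to (n_images, n_samples)
batch = torch.arange(n_images)[:, None]
result = images[batch, indices[..., 1], indices[..., 0]]  # (n_images, n_samples, channels)

print(result.shape)
```

Since the three index tensors broadcast to (n_images, n_samples), the output keeps that shape followed by the untouched channel dimension, so no final reshape is needed.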
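To use a torch function, flatten the dimensions you index into (height and width) as well as the indices; then you can use torch.gather to select the pixels you need: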
import torch

n_images = 4
width = 100
height = 100
channels = 3
n_samples = 30
images = torch.rand((n_images, height, width, channels))
indices = (torch.rand((n_images, n_samples, 2)) * width).long()  # <-- cast to long instead of int32, because torch.gather requires int64 (long) indices
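flatten_images = images.view(n_images, -1, channels)  # (n_images, height * width, channels)
# row-major flat index y * width + x, repeated once per channel -> (n_images, n_samples, channels)
flatten_indices = (indices[..., 1:2] * width + indices[..., 0:1]).repeat(1, 1, channels)
output = torch.gather(flatten_images, 1, flatten_indices).view(n_images, n_samples, channels)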
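To check the gather approach, it can be wrapped in a small function and compared against a plain Python loop. The function name `sample_pixels_gather` is hypothetical, and two details differ from the answer on purpose: `reshape` is used instead of `view` so non-contiguous inputs also work, and `expand` replaces `repeat` so the index tensor is broadcast without copying. A non-square image size is used to confirm that the flat index must be computed with the width.

```python
import torch

def sample_pixels_gather(images, indices):
    """Hypothetical helper. images: (N, H, W, C); indices: (N, S, 2) int64, last dim = (x, y)."""
    n, h, w, c = images.shape
    flat = images.reshape(n, h * w, c)  # works even if images is not contiguous
    # Row-major flat position of pixel (y, x) is y * w + x; keep a trailing dim of size 1
    lin = indices[..., 1:2] * w + indices[..., 0:1]  # (N, S, 1)
    # Broadcast the same flat index over all channels without copying
    return torch.gather(flat, 1, lin.expand(-1, -1, c))  # (N, S, C)

torch.manual_seed(0)
images = torch.rand((4, 60, 80, 3))  # deliberately non-square: H=60, W=80
xs = torch.randint(0, 80, (4, 30))
ys = torch.randint(0, 60, (4, 30))
indices = torch.stack((xs, ys), dim=-1)

out = sample_pixels_gather(images, indices)

# Reference: plain Python loop over images and samples
ref = torch.stack([
    torch.stack([images[i, indices[i, s, 1], indices[i, s, 0]] for s in range(30)])
    for i in range(4)
])
print(out.shape, torch.equal(out, ref))
```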