Numpy/PyTorch 索引,重复分配值

问题描述 投票:0回答:1

我正在尝试将作业和高级索引结合起来。我有一个大小为 (batch, x, y, width) 的数组,我试图将其分配给一个 (batch, width) 数组,以便这些值在 x, y 上重复,但我的形状不会广播。代码的慢速版本看起来像

g = x.clone() # size is batch, x, y, width


for i in batch_size:
   g[i, :, :, :] = x[i, goals[i, 1], goals[i, 0], :] # goals is of size (batch_size, 2)

我尝试通过执行以下操作来使此代码更快:

goal_y = goal[:, 1]
goal_x = goal[:, 0]
g[:, :, :, :] = x[torch.arange(batch_size), goal_y, goal_x, :]

但我收到形状错误,因为我们无法将 [batch_size, width] 的形状广播到 [batch_size, x, y, width] 中。这里的想法是它只会在维度 x, y 上重复,我可以使用 numpy.repeat 扩展 x,但这有点恶心。有更好的方法吗?非常感谢任何帮助。

python numpy indexing pytorch
1个回答
0
投票

解决了。关键是使用“无”。

g[:, None, None, :] = x[torch.arange(batchsize), goal_y, goal_x, :]
© www.soinside.com 2019 - 2024. All rights reserved.