我正在尝试将作业和高级索引结合起来。我有一个大小为 (batch, x, y, width) 的数组,我试图将其分配给一个 (batch, width) 数组,以便这些值在 x, y 上重复,但我的形状不会广播。代码的慢速版本看起来像
g = x.clone() # size is batch, x, y, width
for i in batch_size:
g[i, :, :, :] = x[i, goals[i, 1], goals[i, 0], :] # goals is of size (batch_size, 2)
我尝试通过执行以下操作来使此代码更快:
goal_y = goal[:, 1]
goal_x = goal[:, 0]
g[:, :, :, :] = x[torch.arange(batch_size), goal_y, goal_x, :]
但我收到形状错误,因为我们无法将 [batch_size, width] 的形状广播到 [batch_size, x, y, width] 中。这里的想法是它只会在维度 x, y 上重复,我可以使用 numpy.repeat 扩展 x,但这有点恶心。有更好的方法吗?非常感谢任何帮助。
解决了。关键是使用“无”。
g[:, None, None, :] = x[torch.arange(batchsize), goal_y, goal_x, :]