PyTorch 稀疏张量的散射最大值?

问题描述 投票:0回答:1

我有以下 PyTorch 代码

value_tensor = torch.sparse_coo_tensor(indices=query_indices.t(), values=values, size=(num_lines, img_size, img_size)).to(device=device)
value_tensor = value_tensor.to_dense()
indices = torch.arange(0, img_size * img_size).repeat(len(lines)).to(device=device)
line_tensor_flat = value_tensor.flatten()
img, _ = scatter_max(line_tensor_flat, indices, dim=0)
img = torch.reshape(img, (img_size, img_size))

注意这行:

value_tensor = value_tensor.to_dense()
,这毫不奇怪地慢。

但是,我无法弄清楚如何使用稀疏张量获得相同的结果。该函数调用

reshape
,这在稀疏张量上不可用。我正在使用 Scatter Max,但愿意使用任何有效的东西。

python pytorch
1个回答
0
投票

如果您将传递给

scatter_max
的索引保持稀疏(即仅非零索引),您应该能够直接在稀疏张量上使用
scatter_max

考虑这个例子

query_indices = torch.tensor([
    [0, 0, 0, 1, 1, 1],
    [0, 1, 2, 0, 1, 2],
    [0, 1, 0, 0, 1, 0]
])

values = torch.tensor([1, 2, 3, 4, 5, 6])
num_lines = 2
img_size = 3

value_tensor = torch.sparse_coo_tensor(
    indices=query_indices,
    values=values,
    size=(num_lines, img_size, img_size)
)

# need to coalesce because for some reason sparse_coo_tensor doesn't guarantee uniqueness of indices
value_tensor = value_tensor.coalesce()

然后,将

flat_indices
计算为仅包含非零一维索引的稀疏张量(二维索引转换为类似于您的
arange
的一维索引)

indices = value_tensor.indices()
values = value_tensor.values()

batch_indices = indices[0]        # "line" (in your terminology) indices
row_indices = indices[1]
col_indices = indices[2]
flat_indices = row_indices * img_size + col_indices


您可以使用

flat_indices
scatter_max

flattened_result, _ = scatter_max(
    values, flat_indices, dim=0, dim_size=img_size * img_size
)

per_line_max = flattened_result.reshape(img_size, img_size)

indices

tensor([[0, 0, 0, 1, 1, 1],
        [0, 1, 2, 0, 1, 2],
        [0, 1, 0, 0, 1, 0]])


values

tensor([1, 2, 3, 4, 5, 6])


flat_indices

tensor([0, 4, 6, 0, 4, 6])

per_line_max

tensor([[4, 0, 0],
        [0, 5, 0],
        [6, 0, 0]])

我得到的输出与我从你的代码中得到的输出相同。

© www.soinside.com 2019 - 2024. All rights reserved.