I have the following PyTorch code:
value_tensor = torch.sparse_coo_tensor(indices=query_indices.t(), values=values, size=(num_lines, img_size, img_size)).to(device=device)
value_tensor = value_tensor.to_dense()
indices = torch.arange(0, img_size * img_size).repeat(len(lines)).to(device=device)
line_tensor_flat = value_tensor.flatten()
img, _ = scatter_max(line_tensor_flat, indices, dim=0)
img = torch.reshape(img, (img_size, img_size))
Note this line:
value_tensor = value_tensor.to_dense()
Unsurprisingly, it is slow. However, I can't figure out how to get the same result using a sparse tensor. The code calls
reshape
, which is not available on sparse tensors. I'm using scatter_max, but I'm open to anything that works.
If you keep the indices you pass to
scatter_max
sparse (i.e. only the nonzero indices), you should be able to use scatter_max directly on the sparse tensor.
Consider this example:
query_indices = torch.tensor([
[0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2],
[0, 1, 0, 0, 1, 0]
])
values = torch.tensor([1, 2, 3, 4, 5, 6])
num_lines = 2
img_size = 3
value_tensor = torch.sparse_coo_tensor(
indices=query_indices,
values=values,
size=(num_lines, img_size, img_size)
)
# coalesce() is needed because sparse_coo_tensor does not guarantee unique
# indices; it merges duplicate coordinates (summing their values) and sorts them
value_tensor = value_tensor.coalesce()
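As a side note, coalescing merges duplicate coordinates by summing their values and returns the indices in sorted order. A minimal illustration with hypothetical values:

```python
import torch

# (0, 2) appears twice in the indices; sparse_coo_tensor accepts this as-is
t = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 1],
                          [2, 2, 0]]),
    values=torch.tensor([10, 5, 7]),
    size=(2, 3),
)

tc = t.coalesce()
# tc.indices() -> tensor([[0, 1], [2, 0]])  (duplicates merged, sorted)
# tc.values()  -> tensor([15, 7])           (duplicate values summed)
```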
Then compute
flat_indices
, containing the one-dimensional indices of only the nonzero entries (the 2-D indices converted to 1-D indices, analogous to your arange):
indices = value_tensor.indices()
values = value_tensor.values()
batch_indices = indices[0] # "line" (in your terminology) indices
row_indices = indices[1]
col_indices = indices[2]
flat_indices = row_indices * img_size + col_indices
You can then pass
flat_indices
to scatter_max:
flattened_result, _ = scatter_max(
values, flat_indices, dim=0, dim_size=img_size * img_size
)
per_line_max = flattened_result.reshape(img_size, img_size)
indices
tensor([[0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2],
[0, 1, 0, 0, 1, 0]])
values
tensor([1, 2, 3, 4, 5, 6])
flat_indices
tensor([0, 4, 6, 0, 4, 6])
per_line_max
tensor([[4, 0, 0],
[0, 5, 0],
[6, 0, 0]])
The output I get is identical to the output of your code.
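For completeness, here is a self-contained sketch of the whole pipeline. It swaps torch_scatter's scatter_max for PyTorch's built-in Tensor.scatter_reduce_ with reduce="amax" (available since PyTorch 1.12), which avoids the extra dependency; everything else follows the steps above:

```python
import torch

query_indices = torch.tensor([
    [0, 0, 0, 1, 1, 1],
    [0, 1, 2, 0, 1, 2],
    [0, 1, 0, 0, 1, 0],
])
values = torch.tensor([1, 2, 3, 4, 5, 6])
num_lines, img_size = 2, 3

# build the sparse tensor and coalesce to get unique, sorted indices
value_tensor = torch.sparse_coo_tensor(
    indices=query_indices, values=values,
    size=(num_lines, img_size, img_size),
).coalesce()

# flatten the (row, col) coordinates to 1-D pixel indices
idx = value_tensor.indices()
flat_indices = idx[1] * img_size + idx[2]

# max-reduce the nonzero values into a flat image;
# with include_self=False, cells no index points at keep their initial 0
img = torch.zeros(img_size * img_size, dtype=values.dtype)
img.scatter_reduce_(0, flat_indices, value_tensor.values(),
                    reduce="amax", include_self=False)
img = img.reshape(img_size, img_size)
# img -> tensor([[4, 0, 0],
#                [0, 5, 0],
#                [6, 0, 0]])
```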