查找每行中连续零的最大长度

问题描述 投票:0回答:1

我的目标是找到每行中连续零的最大长度。例如,如果我有一个像

这样的张量
input = torch.tensor([[0, 1, 0, 0, 0, 1],[0, 0, 1, 0, 1, 0],[1, 0, 0, 0, 0, 0]])

我期待得到结果

tensor([3, 2, 5])

我已经使用 numpy 完成了此操作(下面列出,假设“输入”是一个 numpy 二进制矩阵)并且它往往非常高效,但我似乎找不到一种使用至少同样高效的张量的方法。我尝试使用火炬操作遵循类似的逻辑,但性能总是更差。 谢谢!

import numpy as np
input = np.array([[0, 1, 0, 0, 0, 1],[0, 0, 1, 0, 1, 0],[1, 0, 0, 0, 0, 0]])
# Pad the matrix with ones
padded_matrix = np.pad(input, ((0, 0), (1, 1)), constant_values=1)

# Compute differences
diffs = np.diff(padded_matrix, axis=1)

# Identify start and end of zero runs
start_indices = np.where(diffs == -1)
end_indices = np.where(diffs == 1)

#Compute lengths of zero runs
run_lengths = end_indices[1] - start_indices[1]

# Create a result array initialized with zeros
max_zeros = np.zeros(binaryGrid.shape[0], dtype=int)

# Use np.maximum.at to find the maximum run length for each row
np.maximum.at(max_zeros, start_indices[0], run_lengths)

print(max_zeros)
python numpy pytorch torch
1个回答
0
投票

对于火炬实施,您可以使用

scatter_reduce
代替
np.maximum.at

def test_np(x):
    padded_matrix = np.pad(x, ((0, 0), (1, 1)), constant_values=1)

    diffs = np.diff(padded_matrix, axis=1)

    start_indices = np.where(diffs == -1)
    end_indices = np.where(diffs == 1)
    run_lengths = end_indices[1] - start_indices[1]

    max_zeros = np.zeros(x.shape[0], dtype=int)
    np.maximum.at(max_zeros, start_indices[0], run_lengths)
    return max_zeros

def test_torch(x):
    padded_matrix = F.pad(x, (1,1), value=1)
    diffs = torch.diff(padded_matrix, axis=1)
    
    start_indices = torch.where(diffs==-1)
    end_indices = torch.where(diffs == 1)
    run_lengths = end_indices[1] - start_indices[1]
    max_zeros = torch.zeros(x.shape[0]).long().scatter_reduce(
        dim=0,
        index=start_indices[0],
        src=run_lengths,
        reduce='amax'
    )

    return max_zeros

rows = 12
cols = 32
thresh = 0.5
x_torch = (torch.rand(rows, cols)>thresh).long()
x_np = x_torch.numpy()

result_torch = test_torch(x_torch)
result_np = test_np(x_np)

在我的系统上进行基准测试,两个版本在 CPU 上的性能相当。

© www.soinside.com 2019 - 2024. All rights reserved.