我的目标是找到每行中连续零的最大长度。例如,如果我有一个像
这样的张量input = torch.tensor([[0, 1, 0, 0, 0, 1],[0, 0, 1, 0, 1, 0],[1, 0, 0, 0, 0, 0]])
我期待得到结果
tensor([3, 2, 5])
我已经使用 numpy 完成了此操作(下面列出,假设“输入”是一个 numpy 二进制矩阵)并且它往往非常高效,但我似乎找不到一种使用至少同样高效的张量的方法。我尝试使用火炬操作遵循类似的逻辑,但性能总是更差。 谢谢!
import numpy as np
input = np.array([[0, 1, 0, 0, 0, 1],[0, 0, 1, 0, 1, 0],[1, 0, 0, 0, 0, 0]])
# Pad the matrix with ones
padded_matrix = np.pad(input, ((0, 0), (1, 1)), constant_values=1)
# Compute differences
diffs = np.diff(padded_matrix, axis=1)
# Identify start and end of zero runs
start_indices = np.where(diffs == -1)
end_indices = np.where(diffs == 1)
#Compute lengths of zero runs
run_lengths = end_indices[1] - start_indices[1]
# Create a result array initialized with zeros
max_zeros = np.zeros(binaryGrid.shape[0], dtype=int)
# Use np.maximum.at to find the maximum run length for each row
np.maximum.at(max_zeros, start_indices[0], run_lengths)
print(max_zeros)
对于火炬实施,您可以使用
scatter_reduce
代替 np.maximum.at
def test_np(x):
padded_matrix = np.pad(x, ((0, 0), (1, 1)), constant_values=1)
diffs = np.diff(padded_matrix, axis=1)
start_indices = np.where(diffs == -1)
end_indices = np.where(diffs == 1)
run_lengths = end_indices[1] - start_indices[1]
max_zeros = np.zeros(x.shape[0], dtype=int)
np.maximum.at(max_zeros, start_indices[0], run_lengths)
return max_zeros
def test_torch(x):
padded_matrix = F.pad(x, (1,1), value=1)
diffs = torch.diff(padded_matrix, axis=1)
start_indices = torch.where(diffs==-1)
end_indices = torch.where(diffs == 1)
run_lengths = end_indices[1] - start_indices[1]
max_zeros = torch.zeros(x.shape[0]).long().scatter_reduce(
dim=0,
index=start_indices[0],
src=run_lengths,
reduce='amax'
)
return max_zeros
rows = 12
cols = 32
thresh = 0.5
x_torch = (torch.rand(rows, cols)>thresh).long()
x_np = x_torch.numpy()
result_torch = test_torch(x_torch)
result_np = test_np(x_np)
在我的系统上进行基准测试,两个版本在 CPU 上的性能相当。