我想重写这个函数,而不使用 for 循环,按顺序用第一个非零元素替换零元素:
未优化的代码:
last_val=0
for i in range(t.shape[0]-1, -1, -1):
if (t[i] > 0):
last_val = t[i]
else:
t[i] = last_val
输入:
t = torch.tensor([0, 0, 5, 0, 0, 7, 0, 8, 9])
预期输出:
t = torch.tensor([5, 5, 5, 7, 7, 7, 8, 8, 9])
更新了
import numpy as np
t = np.array([0, 0, 5, 0, 0, 7, 0, 8, 9])
non_zero_indices = np.nonzero(t)
if len(non_zero_indices[0]) > 0:
first_non_zero_value = t[non_zero_indices][0]
for i in range(len(t)):
if t[i] == 0:
t[i] = first_non_zero_value
print(t)