我有一些这样的数据:
tensor([0.0872, 0.4737, 0.0954])
我想把它改成
tensor([0,0,0])
只要值大于 0,我就赋值 0,只要值小于 0,我就赋值 1。
我想可能会有像 lamda x 这样的单行代码来处理这个问题,但我是 Python 和 PyTorch 的新手,所以真的不知道该怎么做。
如有任何建议,我们将不胜感激。
import torch
def check_data(data):
threshold = 0
# Use boolean indexing to get the indices where the values are greater than the threshold
indices_greater = data > threshold
# Use the torch.where() function to replace the values
data[indices_greater] = 0
data[~indices_greater] = 1
print(data) # tensor([1., 0., 1.])
data = torch.tensor([0.0872, 0.4737, 0.0954])
check_data(data)
data = torch.tensor([-0.0872, -0.4737, -0.0954])
check_data(data)