如何替换 torch 张量的值?

问题描述 投票:0回答:1

我有一些这样的数据:

tensor([0.0872, 0.4737, 0.0954])

我想把它改成

tensor([0,0,0])

只要值大于 0,我就赋值 0,只要值小于 0,我就赋值 1。

我想可能会有像 lamda x 这样的单行代码来处理这个问题,但我是 Python 和 PyTorch 的新手,所以真的不知道该怎么做。

如有任何建议,我们将不胜感激。

python torch
1个回答
0
投票
import torch


def  check_data(data):
    threshold = 0

    # Use boolean indexing to get the indices where the values are greater than the threshold
    indices_greater = data > threshold

    # Use the torch.where() function to replace the values
    data[indices_greater] = 0
    data[~indices_greater] = 1

    print(data) # tensor([1., 0., 1.])


data = torch.tensor([0.0872, 0.4737, 0.0954])
check_data(data)
data = torch.tensor([-0.0872, -0.4737, -0.0954])
check_data(data)

© www.soinside.com 2019 - 2024. All rights reserved.