如何将一组张量值转换为列表,同时保持格式浮点精度

问题描述 投票:0回答:1

我有 3 个 id 列表(sequence[i].id_column),当我打印它们时,我看到(我不知道出于什么原因)第二个列表的格式发生了变化!例如 4.2 变为 4.2e+00。我尝试将所有这些列表添加到一个集合(unique_values)中,然后将其转换为列表(unique_values_tensor),以使其中的值是唯一的(unique_values-sorted)。但是我看到 unique_values_sorted 有不同的格式。请帮助我如何将所有这三个列表保留在一个具有唯一值(没有重复值)的列表中,并采用正确的格式,如第一个 id_column_tensor 显示的值。

        unique_values = set()
        np.set_printoptions(suppress=True)
        torch.set_printoptions(precision=1)  
        for i in range(start_timestep-1, start_timestep+ timestep-1):
           if sequence[i].x.size(0) > 0:
                
                id_column_tensor = torch.tensor(sequence[i].id_column, dtype=torch.float32)

                id_column_list = [float(f"{value:.4f}") if isinstance(value, float) else value for value in sequence[i].id_column]  #id_column_tensor
                unique_values.update(id_column_tensor)
                print('id_column_tensor:',id_column_tensor)
            else:
                raise ValueError(f"The tensor at sequence[{i}].x is empty.")
        print('unique_values:',unique_values)
        # Convert unique values to a tensor to get unique sorted values
        unique_values_tensor = torch.tensor(list(unique_values))   #, dtype=torch.float32
        unique_values_sorted = torch.unique(unique_values_tensor, sorted=True)
        
        print('unique_values_sorted:', unique_values_sorted)   
    
id_column_tensor: tensor([  12.0,   13.0,   13.1,   13.2,   14.0,   14.1,   15.0,   15.1,   16.0,
          16.1,   16.2,   17.0,   17.1,   18.0,   18.1,   18.2,   19.0,   19.1,
        5001.0, 5004.0, 5008.0, 5010.0])
id_column_tensor: tensor([0.0e+00, 1.0e+00, 2.0e+00, 3.0e+00, 3.1e+00, 3.2e+00, 4.0e+00, 4.1e+00,
        4.2e+00, 5.0e+00, 5.1e+00, 6.0e+00, 6.1e+00, 6.2e+00, 7.0e+00, 7.1e+00,
        8.0e+00, 8.1e+00, 8.2e+00, 9.0e+00, 9.1e+00, 9.2e+00, 1.0e+01, 1.1e+01,
        1.1e+01, 1.1e+01, 1.2e+01, 1.2e+01, 1.3e+01, 1.3e+01, 1.4e+01, 5.0e+03,
        5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03])
id_column_tensor: tensor([   5.1,    6.0,    6.1,    6.2,    7.0,    7.1,    8.0,    8.1,    8.2,
           9.0,    9.1,    9.2,   10.0,   11.0,   11.1,   11.2,   12.0,   12.1,
          13.0,   13.1,   13.2,   14.0,   14.1,   15.0,   15.1,   16.0,   16.1,
          16.2,   17.0,   17.1,   18.0,   18.1,   18.2,   19.0,   19.1, 5001.0,
        5004.0, 5006.0, 5007.0, 5008.0, 5009.0, 5010.0])
unique_values: {tensor(5001.), tensor(3.1), tensor(16.1), tensor(2.), tensor(4.1), tensor(10.), tensor(16.2), tensor(13.2), tensor(16.1), tensor(5004.), tensor(5005.), tensor(13.2), tensor(6.), tensor(19.), tensor(5009.), tensor(11.1), tensor(7.), tensor(13.1), tensor(17.), tensor(8.2), tensor(9.1), tensor(11.), tensor(18.1), tensor(5006.), tensor(14.), tensor(5008.), tensor(6.2), tensor(19.1), tensor(18.), tensor(14.), tensor(11.2), tensor(7.), tensor(13.1), tensor(8.1), tensor(9.), tensor(5.), tensor(11.1), tensor(9.), tensor(5007.), tensor(14.1), tensor(6.1), tensor(5001.), tensor(5008.), tensor(18.2), tensor(4.2), tensor(8.2), tensor(5001.), tensor(12.), tensor(9.2), tensor(17.1), tensor(6.2), tensor(11.2), tensor(9.1), tensor(13.), tensor(5008.), tensor(15.), tensor(5.1), tensor(6.), tensor(5.1), tensor(5004.), tensor(16.), tensor(17.), tensor(5002.), tensor(12.1), tensor(14.1), tensor(17.1), tensor(5010.), tensor(18.), tensor(12.), tensor(9.2), tensor(6.1), tensor(7.1), tensor(15.1), tensor(8.1), tensor(5006.), tensor(5010.), tensor(14.), tensor(19.), tensor(5003.), tensor(13.), tensor(18.1), tensor(12.1), tensor(10.), tensor(12.), tensor(16.2), tensor(16.), tensor(15.1), tensor(3.), tensor(3.2), tensor(8.), tensor(5007.), tensor(15.), tensor(0.), tensor(8.), tensor(5004.), tensor(13.1), tensor(7.1), tensor(19.1), tensor(4.), tensor(18.2), tensor(1.), tensor(13.), tensor(11.)}
unique_values_sorted: tensor([0.0e+00, 1.0e+00, 2.0e+00, 3.0e+00, 3.1e+00, 3.2e+00, 4.0e+00, 4.1e+00,
        4.2e+00, 5.0e+00, 5.1e+00, 6.0e+00, 6.1e+00, 6.2e+00, 7.0e+00, 7.1e+00,
        8.0e+00, 8.1e+00, 8.2e+00, 9.0e+00, 9.1e+00, 9.2e+00, 1.0e+01, 1.1e+01,
        1.1e+01, 1.1e+01, 1.2e+01, 1.2e+01, 1.3e+01, 1.3e+01, 1.3e+01, 1.4e+01,
        1.4e+01, 1.5e+01, 1.5e+01, 1.6e+01, 1.6e+01, 1.6e+01, 1.7e+01, 1.7e+01,
        1.8e+01, 1.8e+01, 1.8e+01, 1.9e+01, 1.9e+01, 5.0e+03, 5.0e+03, 5.0e+03,
        5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03])
python pytorch tensor
1个回答
0
投票

问题只是其中一种打印格式,而不是张量的数据格式。有些东西会触发你的张量以有时被称为科学记数法的方式打印。

您可以通过

torch.set_printoptions()
手动关闭科学计数法打印:

import torch

t = torch.tensor([0.0, 1.0, 5000])

print(t)
# >>> tensor([0.0000e+00, 1.0000e+00, 5.0000e+03])

torch.set_printoptions(sci_mode=False)  # Disable scientific notation for printing

print(t)
# >>> tensor([        0.,         1.,      5000.])

请注意,这是一个全局变化,即所有未来的张量打印输出 你的程序的运行时间将会受到影响。当然,您可以随时通过再次调用

torch.set_printoptions()
来设置新的打印选项。

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.