我有 3 个 id 列表(sequence[i].id_column),当我打印它们时,我看到(我不知道出于什么原因)第二个列表的格式发生了变化!例如 4.2 变为 4.2e+00。我尝试将所有这些列表添加到一个集合(unique_values)中,然后将其转换为列表(unique_values_tensor),以使其中的值是唯一的(unique_values-sorted)。但是我看到 unique_values_sorted 有不同的格式。请帮助我如何将所有这三个列表保留在一个具有唯一值(没有重复值)的列表中,并采用正确的格式,如第一个 id_column_tensor 显示的值。
unique_values = set()
np.set_printoptions(suppress=True)
torch.set_printoptions(precision=1)
for i in range(start_timestep-1, start_timestep+ timestep-1):
if sequence[i].x.size(0) > 0:
id_column_tensor = torch.tensor(sequence[i].id_column, dtype=torch.float32)
id_column_list = [float(f"{value:.4f}") if isinstance(value, float) else value for value in sequence[i].id_column] #id_column_tensor
unique_values.update(id_column_tensor)
print('id_column_tensor:',id_column_tensor)
else:
raise ValueError(f"The tensor at sequence[{i}].x is empty.")
print('unique_values:',unique_values)
# Convert unique values to a tensor to get unique sorted values
unique_values_tensor = torch.tensor(list(unique_values)) #, dtype=torch.float32
unique_values_sorted = torch.unique(unique_values_tensor, sorted=True)
print('unique_values_sorted:', unique_values_sorted)
id_column_tensor: tensor([ 12.0, 13.0, 13.1, 13.2, 14.0, 14.1, 15.0, 15.1, 16.0,
16.1, 16.2, 17.0, 17.1, 18.0, 18.1, 18.2, 19.0, 19.1,
5001.0, 5004.0, 5008.0, 5010.0])
id_column_tensor: tensor([0.0e+00, 1.0e+00, 2.0e+00, 3.0e+00, 3.1e+00, 3.2e+00, 4.0e+00, 4.1e+00,
4.2e+00, 5.0e+00, 5.1e+00, 6.0e+00, 6.1e+00, 6.2e+00, 7.0e+00, 7.1e+00,
8.0e+00, 8.1e+00, 8.2e+00, 9.0e+00, 9.1e+00, 9.2e+00, 1.0e+01, 1.1e+01,
1.1e+01, 1.1e+01, 1.2e+01, 1.2e+01, 1.3e+01, 1.3e+01, 1.4e+01, 5.0e+03,
5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03])
id_column_tensor: tensor([ 5.1, 6.0, 6.1, 6.2, 7.0, 7.1, 8.0, 8.1, 8.2,
9.0, 9.1, 9.2, 10.0, 11.0, 11.1, 11.2, 12.0, 12.1,
13.0, 13.1, 13.2, 14.0, 14.1, 15.0, 15.1, 16.0, 16.1,
16.2, 17.0, 17.1, 18.0, 18.1, 18.2, 19.0, 19.1, 5001.0,
5004.0, 5006.0, 5007.0, 5008.0, 5009.0, 5010.0])
unique_values: {tensor(5001.), tensor(3.1), tensor(16.1), tensor(2.), tensor(4.1), tensor(10.), tensor(16.2), tensor(13.2), tensor(16.1), tensor(5004.), tensor(5005.), tensor(13.2), tensor(6.), tensor(19.), tensor(5009.), tensor(11.1), tensor(7.), tensor(13.1), tensor(17.), tensor(8.2), tensor(9.1), tensor(11.), tensor(18.1), tensor(5006.), tensor(14.), tensor(5008.), tensor(6.2), tensor(19.1), tensor(18.), tensor(14.), tensor(11.2), tensor(7.), tensor(13.1), tensor(8.1), tensor(9.), tensor(5.), tensor(11.1), tensor(9.), tensor(5007.), tensor(14.1), tensor(6.1), tensor(5001.), tensor(5008.), tensor(18.2), tensor(4.2), tensor(8.2), tensor(5001.), tensor(12.), tensor(9.2), tensor(17.1), tensor(6.2), tensor(11.2), tensor(9.1), tensor(13.), tensor(5008.), tensor(15.), tensor(5.1), tensor(6.), tensor(5.1), tensor(5004.), tensor(16.), tensor(17.), tensor(5002.), tensor(12.1), tensor(14.1), tensor(17.1), tensor(5010.), tensor(18.), tensor(12.), tensor(9.2), tensor(6.1), tensor(7.1), tensor(15.1), tensor(8.1), tensor(5006.), tensor(5010.), tensor(14.), tensor(19.), tensor(5003.), tensor(13.), tensor(18.1), tensor(12.1), tensor(10.), tensor(12.), tensor(16.2), tensor(16.), tensor(15.1), tensor(3.), tensor(3.2), tensor(8.), tensor(5007.), tensor(15.), tensor(0.), tensor(8.), tensor(5004.), tensor(13.1), tensor(7.1), tensor(19.1), tensor(4.), tensor(18.2), tensor(1.), tensor(13.), tensor(11.)}
unique_values_sorted: tensor([0.0e+00, 1.0e+00, 2.0e+00, 3.0e+00, 3.1e+00, 3.2e+00, 4.0e+00, 4.1e+00,
4.2e+00, 5.0e+00, 5.1e+00, 6.0e+00, 6.1e+00, 6.2e+00, 7.0e+00, 7.1e+00,
8.0e+00, 8.1e+00, 8.2e+00, 9.0e+00, 9.1e+00, 9.2e+00, 1.0e+01, 1.1e+01,
1.1e+01, 1.1e+01, 1.2e+01, 1.2e+01, 1.3e+01, 1.3e+01, 1.3e+01, 1.4e+01,
1.4e+01, 1.5e+01, 1.5e+01, 1.6e+01, 1.6e+01, 1.6e+01, 1.7e+01, 1.7e+01,
1.8e+01, 1.8e+01, 1.8e+01, 1.9e+01, 1.9e+01, 5.0e+03, 5.0e+03, 5.0e+03,
5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03, 5.0e+03])
问题只是其中一种打印格式,而不是张量的数据格式。有些东西会触发你的张量以有时被称为科学记数法的方式打印。
torch.set_printoptions()
手动关闭科学计数法打印:
import torch
t = torch.tensor([0.0, 1.0, 5000])
print(t)
# >>> tensor([0.0000e+00, 1.0000e+00, 5.0000e+03])
torch.set_printoptions(sci_mode=False) # Disable scientific notation for printing
print(t)
# >>> tensor([ 0., 1., 5000.])
请注意,这是一个全局变化,即所有未来的张量打印输出 你的程序的运行时间将会受到影响。当然,您可以随时通过再次调用
torch.set_printoptions()
来设置新的打印选项。