How do I parse a CSV file containing a column of tensor values?


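To save memory in my notebook, I saved the output of my image-processing transforms for the dataset to a CSV file that I can load back into the notebook. However, after trying various ways to parse it, the column is still recognized as a string.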

Here is an example of what one tensor looks like in the CSV file (it is a string):

'tensor([[[-0.8849, -0.8849, -0.9192, ..., -1.4329, -2.1179, -2.1179], [-0.9192, -0.8849, -0.8678, ..., -1.3987, -2.1179, -2.1179], [-0.9020, -0.8849, -0.8678, ..., -1.3644, -2.1179, -2.1179], ..., [-2.1179, -2.1179, -1.4158, ..., -0.9877, -1.0048, -0.9877], [-2.1179, -2.1179, -1.3644, ..., -0.9877, -1.0048, -1.0048], [-2.1179, -2.1179, -1.3130, ..., -0.9705, -1.0390, -1.0390]], [[-0.3200, -0.3200, -0.3550, ..., -1.1779, -2.0357, -2.0357], [-0.3550, -0.3200, -0.3025, ..., -1.1604, -2.0357, -2.0357], [-0.3375, -0.3200, -0.3200, ..., -1.1604, -2.0357, -2.0357], ..., [-2.0357, -2.0357, -1.3004, ..., -1.0903, -1.1078, -1.0903], [-2.0357, -2.0357, -1.2829, ..., -1.0903, -1.1078, -1.1253], [-2.0357, -2.0357, -1.2654, ..., -1.1078, -1.1779, -1.1779]], [[-1.1944, -1.1770, -1.1770, ..., -1.6650, -1.8044, -1.8044], [-1.1944, -1.1596, -1.1247, ..., -1.6476, -1.8044, -1.8044], [-1.1421, -1.1247, -1.1073, ..., -1.6302, -1.8044, -1.8044], ..., [-1.8044, -1.8044, -1.5604, ..., -1.0724, -1.0898, -1.0724], [-1.8044, -1.8044, -1.5081, ..., -1.0724, -1.1073, -1.1073], [-1.8044, -1.8044, -1.4559, ..., -1.0724, -1.1770, -1.1770]]])'

This is the error:

---------------------------------------------------------------------------
----> 3 train_set["0"] = train_set["0"].apply(parse_tensor_string)

ValueError: could not convert string to float: '[-0.4226'

I tried this:

import numpy as np

def parse_tensor_string(tensor_string):
    # Extract the flattened tensor data from the string
    start_index = tensor_string.find("[[")
    end_index = tensor_string.find("]]")
    data_string = tensor_string[start_index+2:end_index]

    # Remove any extraneous characters
    data_string = data_string.replace("\n", "").replace("  ", "").replace(",", "")

    # Convert each substring to a float
    tensor_list = data_string.split()
    tensor_floats = [float(s) for s in tensor_list]

    # Reshape the resulting array to match the original tensor shape
    tensor_array = np.array(tensor_floats).reshape((3, 224, 224))

    # Return the tensor as a numpy array
    return tensor_array

train_set["0"] = train_set["0"].apply(parse_tensor_string)
python pandas deep-learning pytorch image-classification
1 answer

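Unfortunately, this is not possible with the string you provided. What you saved is the truncated string representation of the tensor (indicated by the "..." markers), which is the text that ends up in a CSV cell when a DataFrame column holds tensor objects. When a tensor has more than 1000 elements, PyTorch's default print options (threshold=1000, edgeitems=3) show only the first and last 3 entries along each dimension. As a result, only the "edges" of the tensor are in the string: at most 3 × 6 × 6 = 108 of the original 3 × 224 × 224 = 150,528 values. The full tensor of shape (3, 224, 224) is simply not in the string, so it cannot be loaded back.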
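The truncation can be reproduced without the original data. This sketch uses NumPy in place of torch, which is a substitution: NumPy also summarizes arrays with more than 1000 elements by default, showing 3 items at each edge. The all-zeros image is a hypothetical stand-in for one transformed image. Turning summarization off (np.array2string with threshold=sys.maxsize) makes every element appear in the text.

```python
import sys

import numpy as np

# Hypothetical stand-in for one transformed image (NumPy instead of torch)
image = np.zeros((3, 224, 224), dtype=np.float32)

# Default string form: summarized, like the text stored in the CSV cell
summarized = str(image)

# Summarization disabled: every one of the elements is printed
full = np.array2string(image, threshold=sys.maxsize)

print("..." in summarized)  # True: the middle of each axis was dropped
print("..." in full)        # False
print(image.size)           # 150528
```

In PyTorch, the corresponding switch is torch.set_printoptions(threshold=...). Even an untruncated repr is a bulky way to store this much data, though.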

However, I have modified your function to parse everything that can be recovered from the string you provided:

Modified code:

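import ast
import torch

def parse_tensor_string(data_string):
    # Remove extraneous characters: newlines, the "..." summary markers, spaces,
    # the doubled commas left behind by "...", and the "tensor(" ... ")" wrapper
    result = (data_string.replace("\n", "").replace("...", "").replace(" ", "")
              .replace(",,", ",").replace("tensor(", "").replace(")", ""))
    # Parse the remaining nested-list literal into Python lists
    result = ast.literal_eval(result)
    # Convert the nested lists to a tensor
    return torch.tensor(result)

train_set["0"] = train_set["0"].apply(parse_tensor_string)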

Output:

tensor([[[-0.8849, -0.8849, -0.9192, -1.4329, -2.1179, -2.1179],
         [-0.9192, -0.8849, -0.8678, -1.3987, -2.1179, -2.1179],
         [-0.9020, -0.8849, -0.8678, -1.3644, -2.1179, -2.1179],
         [-2.1179, -2.1179, -1.4158, -0.9877, -1.0048, -0.9877],
         [-2.1179, -2.1179, -1.3644, -0.9877, -1.0048, -1.0048],
         [-2.1179, -2.1179, -1.3130, -0.9705, -1.0390, -1.0390]],

        [[-0.3200, -0.3200, -0.3550, -1.1779, -2.0357, -2.0357],
         [-0.3550, -0.3200, -0.3025, -1.1604, -2.0357, -2.0357],
         [-0.3375, -0.3200, -0.3200, -1.1604, -2.0357, -2.0357],
         [-2.0357, -2.0357, -1.3004, -1.0903, -1.1078, -1.0903],
         [-2.0357, -2.0357, -1.2829, -1.0903, -1.1078, -1.1253],
         [-2.0357, -2.0357, -1.2654, -1.1078, -1.1779, -1.1779]],

        [[-1.1944, -1.1770, -1.1770, -1.6650, -1.8044, -1.8044],
         [-1.1944, -1.1596, -1.1247, -1.6476, -1.8044, -1.8044],
         [-1.1421, -1.1247, -1.1073, -1.6302, -1.8044, -1.8044],
         [-1.8044, -1.8044, -1.5604, -1.0724, -1.0898, -1.0724],
         [-1.8044, -1.8044, -1.5081, -1.0724, -1.1073, -1.1073],
         [-1.8044, -1.8044, -1.4559, -1.0724, -1.1770, -1.1770]]])
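The values can only be parsed back if every one of them is written to the file in the first place. The sketch below shows one possible approach, under stated assumptions:

- Each array is stored as JSON text holding its shape and its flattened values.
- NumPy stands in for torch. A CPU tensor can be converted with .numpy() before writing, and the result can be turned back with torch.from_numpy.
- The names tensor_to_cell and cell_to_tensor are hypothetical.
- A small 3×4×4 array keeps the demo fast.

```python
import csv
import io
import json

import numpy as np


def tensor_to_cell(array):
    # Hypothetical helper: serialize shape + every value as JSON (no truncation)
    return json.dumps({"shape": list(array.shape), "data": array.ravel().tolist()})


def cell_to_tensor(cell, dtype=np.float32):
    # Hypothetical helper: rebuild the array from the JSON text in a CSV cell
    payload = json.loads(cell)
    return np.asarray(payload["data"], dtype=dtype).reshape(payload["shape"])


# Small random array standing in for a transformed image
rng = np.random.default_rng(0)
original = rng.standard_normal((3, 4, 4)).astype(np.float32)

# Write a one-column CSV (header "0", as in the question) to an in-memory file
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["0"])
writer.writerow([tensor_to_cell(original)])

# Read it back and parse the cell
buffer.seek(0)
rows = list(csv.reader(buffer))
restored = cell_to_tensor(rows[1][0])

print(restored.shape)                      # (3, 4, 4)
print(np.array_equal(original, restored))  # True
```

For full 3×224×224 images, each cell becomes megabytes of text. That is well over the csv module's default field size limit of 131072 characters (see csv.field_size_limit). If memory or disk space is the concern, binary formats such as np.save or torch.save are usually more compact than text.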