为了节省笔记本(notebook)的内存,我把数据集经过图像处理(变换)后的输出保存到了一个 CSV(逗号分隔值)文件中,以便之后再加载到笔记本里。但是在尝试了各种解析方法之后,读回来的张量仍然被识别为字符串。
下面是 CSV 文件中一个张量的示例,它是一个字符串:
'tensor([[[-0.8849, -0.8849, -0.9192, ..., -1.4329, -2.1179, -2.1179], [-0.9192, -0.8849, -0.8678, ..., -1.3987, -2.1179, -2.1179], [-0.9020, -0.8849, -0.8678, ..., -1.3644, -2.1179, -2.1179], ..., [-2.1179, -2.1179, -1.4158, ..., -0.9877, -1.0048, -0.9877], [-2.1179, -2.1179, -1.3644, ..., -0.9877, -1.0048, -1.0048], [-2.1179, -2.1179, -1.3130, ..., -0.9705, -1.0390, -1.0390]], [[-0.3200, -0.3200, -0.3550, ..., -1.1779, -2.0357, -2.0357], [-0.3550, -0.3200, -0.3025, ..., -1.1604, -2.0357, -2.0357], [-0.3375, -0.3200, -0.3200, ..., -1.1604, -2.0357, -2.0357], ..., [-2.0357, -2.0357, -1.3004, ..., -1.0903, -1.1078, -1.0903], [-2.0357, -2.0357, -1.2829, ..., -1.0903, -1.1078, -1.1253], [-2.0357, -2.0357, -1.2654, ..., -1.1078, -1.1779, -1.1779]], [[-1.1944, -1.1770, -1.1770, ..., -1.6650, -1.8044, -1.8044], [-1.1944, -1.1596, -1.1247, ..., -1.6476, -1.8044, -1.8044], [-1.1421, -1.1247, -1.1073, ..., -1.6302, -1.8044, -1.8044], ..., [-1.8044, -1.8044, -1.5604, ..., -1.0724, -1.0898, -1.0724], [-1.8044, -1.8044, -1.5081, ..., -1.0724, -1.1073, -1.1073], [-1.8044, -1.8044, -1.4559, ..., -1.0724, -1.1770, -1.1770]]])'
这是错误:
---------------------------------------------------------------------------
----> 3 train_set["0"] = train_set["0"].apply(parse_tensor_string)
ValueError: could not convert string to float: '[-0.4226'(无法将字符串转换为浮点数)
我试过这个:
def parse_tensor_string(tensor_string, shape=(3, 224, 224)):
    """Parse the printed form of a tensor into a NumPy array.

    Args:
        tensor_string: Text such as ``"tensor([[1.0, 2.0], [3.0, 4.0]])"``.
            Only the span from the first ``[`` to the last ``]`` is read, so
            the wrapper name and any trailing ``dtype=...`` are ignored.
        shape: Shape the flat list of parsed values is reshaped to.

    Returns:
        A ``numpy.ndarray`` of floats with the requested ``shape``.

    Raises:
        ValueError: If no bracketed data is present, if the text is a
            truncated printout containing ``...`` (the omitted values cannot
            be recovered from the string), if a token is not a number, or if
            the number of values does not fit ``shape``.
    """
    first = tensor_string.find("[")
    last = tensor_string.rfind("]")
    if first == -1 or last < first:
        raise ValueError("no bracketed tensor data found in string")
    payload = tensor_string[first:last + 1]

    # A summarized printout only keeps the edge values; fail loudly instead of
    # producing a confusing float-conversion or reshape error later on.
    if "..." in payload:
        raise ValueError(
            "tensor string is truncated ('...'); the full data cannot be "
            "recovered from it"
        )

    # Turn brackets and commas into spaces rather than deleting them: deleting
    # the separators would glue every number into one unparsable token.
    tokens = payload.translate(str.maketrans("[],", "   ")).split()
    values = [float(token) for token in tokens]
    return np.array(values).reshape(shape)
# Run the parser over every entry of column "0" and store the results back
# into that same column.
train_set["0"] = train_set["0"].apply(parse_tensor_string)
很遗憾,从您提供的字符串来看,这是不可能做到的。您保存的字符串是张量的截断字符串表示(用符号 `...` 表示被省略的部分;PyTorch 在打印元素较多的张量时默认会省略中间的内容)。这意味着字符串中只保留了张量的“边缘”部分。原来大小为 (3, 224, 224) 的完整张量根本不在字符串中,因此无法加载。
但是,我已经修改了你的函数来解析你提供的字符串中可能的所有内容:
修改后的代码:
import ast
def parse_tensor_string(data_string):
    """Rebuild a tensor from its printed (possibly truncated) form.

    Truncation markers (``...``) are dropped, so for a summarized printout
    the result contains only the values that are present in the text, not
    the original full-size tensor.

    Args:
        data_string: Text such as ``"tensor([[1.0, ..., 2.0]])"``. Anything
            after the closing bracket (e.g. ``dtype=...``) is ignored, so the
            result uses the default dtype chosen by ``torch.tensor``.

    Returns:
        A ``torch.Tensor`` built from the values found in ``data_string``.

    Raises:
        ValueError, SyntaxError: If the remaining text is not a valid Python
            literal (raised by ``ast.literal_eval``).
    """
    # Drop all whitespace and the "..." markers; removing a marker leaves two
    # adjacent commas behind, which are collapsed into one.
    compact = "".join(data_string.split()).replace("...", "").replace(",,", ",")

    first = compact.find("[")
    last = compact.rfind("]")
    if first != -1 and last > first:
        # Keep only the nested-list literal so that the wrapper name and any
        # trailing metadata do not break literal_eval.
        literal = compact[first:last + 1]
    else:
        # Zero-dimensional tensors print without brackets, e.g. "tensor(1.5)".
        literal = compact.replace("tensor(", "").replace(")", "")

    return torch.tensor(ast.literal_eval(literal))
输出:
tensor([[[-0.8849, -0.8849, -0.9192, -1.4329, -2.1179, -2.1179],
[-0.9192, -0.8849, -0.8678, -1.3987, -2.1179, -2.1179],
[-0.9020, -0.8849, -0.8678, -1.3644, -2.1179, -2.1179],
[-2.1179, -2.1179, -1.4158, -0.9877, -1.0048, -0.9877],
[-2.1179, -2.1179, -1.3644, -0.9877, -1.0048, -1.0048],
[-2.1179, -2.1179, -1.3130, -0.9705, -1.0390, -1.0390]],
[[-0.3200, -0.3200, -0.3550, -1.1779, -2.0357, -2.0357],
[-0.3550, -0.3200, -0.3025, -1.1604, -2.0357, -2.0357],
[-0.3375, -0.3200, -0.3200, -1.1604, -2.0357, -2.0357],
[-2.0357, -2.0357, -1.3004, -1.0903, -1.1078, -1.0903],
[-2.0357, -2.0357, -1.2829, -1.0903, -1.1078, -1.1253],
[-2.0357, -2.0357, -1.2654, -1.1078, -1.1779, -1.1779]],
[[-1.1944, -1.1770, -1.1770, -1.6650, -1.8044, -1.8044],
[-1.1944, -1.1596, -1.1247, -1.6476, -1.8044, -1.8044],
[-1.1421, -1.1247, -1.1073, -1.6302, -1.8044, -1.8044],
[-1.8044, -1.8044, -1.5604, -1.0724, -1.0898, -1.0724],
[-1.8044, -1.8044, -1.5081, -1.0724, -1.1073, -1.1073],
[-1.8044, -1.8044, -1.4559, -1.0724, -1.1770, -1.1770]]])