我目前正在研究涉及图像数据的二元分类任务。首先,我必须检查我的数据集。但是,我遇到了
DataLoader
的问题。
PyTorch官方网站上有这样写的
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
当他们设置
training data
时,他们将数据类型转换为张量。他们只是使用 imshow(matplotlib)。但是当我自己尝试这个过程时,错误TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
困扰着我。
当我向 GPT4 询问这个问题时,它说“PyTorch 和 matplotlib 是兼容的。”然而,当我再次询问我提供的代码时,它提到:“在使用 imshow 之前,您需要将 PyTorch 张量转换为 NumPy 数组。”哪一种说法是准确的?
第二个应该是正确的说法。你应该改成这个
plt.imshow(img.squeeze().numpy(), cmap="gray")