我正在尝试使用来自torchvision.datasets
的MNIST数据集。它似乎是以N x H x W (uint8)
(批量维度,高度,宽度)张量提供的。然而,用于处理图像的所有pytorch类(例如Conv2d
)都需要N x C x H x W (float32)
张量,其中C
是颜色通道的数量。我试图添加添加ToTensor
变换,但没有添加颜色通道。
有没有办法使用torchvision.transforms
添加这个额外的维度?对于原始的tensor
,我们可以做.unsqueeze(1)
,但这看起来不是一个非常优雅的解决方案。我只是想以“正确”的方式做到这一点。
这是失败的转换。
import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])
我有一个误解:dataset.train_data
不受指定的transform
影响,只有DataLoader(dataset,...)
的输出。检查data
后
for data, _ in DataLoader(dataset):
break
我们可以看到ToTensor
实际上完全符合预期。