通过转换将通道添加到MNIST?

问题描述 投票:0回答:1

我正在尝试使用来自torchvision.datasets的MNIST数据集。它似乎是以N x H x W (uint8)(批量维度,高度,宽度)张量提供的。然而,用于处理图像的所有pytorch类(例如Conv2d)都需要N x C x H x W (float32)张量,其中C是颜色通道的数量。我试图添加添加ToTensor变换,但没有添加颜色通道。

有没有办法使用torchvision.transforms添加这个额外的维度?对于原始的tensor,我们可以做.unsqueeze(1),但这看起来不是一个非常优雅的解决方案。我只是想以“正确”的方式做到这一点。

这是失败的转换。

import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])
pytorch mnist python-3.7 torchvision
1个回答
0
投票

我有一个误解:dataset.train_data不受指定的transform影响,只有DataLoader(dataset,...)的输出。检查data

for data, _ in DataLoader(dataset):
    break

我们可以看到ToTensor实际上完全符合预期。

© www.soinside.com 2019 - 2024. All rights reserved.