获得训练集标签的问题

问题描述 投票:1回答:1

我已经使用train_test_split函数将我的数据划分为X_trainX_testy_trainy_test,然后使用utils.data.DataLoader将其提供给我的CNN,但问题是我没有知道如何访问标签张量以创建混淆矩阵并将其与我的预测张量进行比较。我知道这是一个基本问题,但是无论如何,我们都会感谢您的帮助。

X_train, X_test, y_train, y_test = train_test_split(faces, emotions, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=41)

我用过

train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)

用于将数据馈送到我的网络似乎您可以通过在train_set之后键入目标属性(例如train_set.targets)来访问标签,但这种方式不适用于我。如何获得标签?

python label classification conv-neural-network dataloader
1个回答
0
投票

PyTorch的DataLoader对象大致是这样使用的:

for i, (inputs, labels) in enumerate(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

通常,我建议使用两个DataLoader,一个用于训练,另一个用于测试/验证。由于您要创建一个混淆矩阵,因此您可以简单地通过numpy数组y_train和预测preds访问标签,例如通过将它们在循环内连接到一个numpy数组。

有关如何使用DataLoader的更多信息,我建议看一下这个非常好的教程:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

© www.soinside.com 2019 - 2024. All rights reserved.