无效的参数组合 - eq()

问题描述 投票:0回答:1

我正在使用代码共享here来测试CNN图像分类器。当我调用测试函数时,我在line 155上遇到了这个错误:

test_acc += torch.sum(prediction == labels.data)
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other)
      didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m)
 * (Number other)
      didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m)

test功能的片段:

def test():
    model.eval()
    test_acc = 0.0
    for i, (images, labels) in enumerate(test_loader):

        if cuda_avail:
                images = Variable(images.cuda())
                labels = Variable(labels.cuda())

        #Predict classes using images from the test set
        outputs = model(images)
        _,prediction = torch.max(outputs.data, 1)
        prediction = prediction.cpu().numpy()
        test_acc += torch.sum(prediction == labels.data) #line 155



    #Compute the average acc and loss over all 10000 test images
    test_acc = test_acc / 10000

return test_acc

经过快速搜索后,我发现错误可能与predictionlabels之间的比较有关,就像在这个SO question中看到的那样。

我该如何修复此问题而不是扰乱其余的代码?

python numpy image-processing machine-learning pytorch
1个回答
1
投票

你为什么在这里有.numpy() prediction = prediction.cpu().numpy()?这样你就可以将PyTorch张量转换为NumPy数组,使其与labels.data不兼容。

删除.numpy()部分应该解决问题。

© www.soinside.com 2019 - 2024. All rights reserved.