Input type (double) and bias type (float) should be the same


I want to train a CNN to reconstruct images and took the code from someone else (though it may have been written for an older version of Python), but I keep getting this error.

    from skimage.data import brain
    from skimage.transform import resize, rescale

    import numpy as np
    import torch, torch.nn as nn
    import cv2

    def cv2disp(win, ima, xp, yp, sc): cv2.imshow(win, rescale(ima, sc, False) *1.0/(np.max(ima)+1e-15)); cv2.moveWindow(win, xp, yp)
    def np_to_00torch(np_array):   return torch.from_numpy(np_array).unsqueeze(0).unsqueeze(0)
    def torch_to_np(torch_array):  return np.squeeze(torch_array.detach().cpu().numpy())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"); print(device)

    nxd = 128; disp_scale = 3
    nrd = int(nxd * 1.42); nphi = int(nxd*1.42)

    brainimage=brain()
    true_object_np      = resize(brainimage[9,30:-1, :-30], (nxd,nxd), anti_aliasing=False)
    true_object_torch   = np_to_00torch(true_object_np).to(device)


    cv2disp("True", true_object_np, 0, 0, disp_scale)



    #--------------------------------------Torch System Matrix----------------------------------------------
    def make_torch_system_matrix(nxd, nrd, nphi):
        system_matrix = torch.zeros(nrd*nphi, nxd*nxd) # rows = num sino bins, cols = num image pixels
        for xv in range(nxd):
            for yv in range(nxd):
                for ph in range(nphi):
                    yp = -(xv-(nxd*0.5)) * np.sin(ph*np.pi/nphi) + (yv-(nxd*0.5)) * np.cos(ph*np.pi/nphi)
                    yp_bin = int(yp + nrd/2)
                    system_matrix[yp_bin + ph*nrd, xv + yv*nxd] = 1.0
        return system_matrix


    def fp_system_torch(image, sys_mat, nxd, nrd, nphi):
        fp =  torch.reshape(image, (nxd*nxd,1))
        fb = torch.mm(sys_mat, fp)
        return torch.reshape(fb, (nphi, nrd))
    def bp_system_torch(sino, sys_mat, nxd, nrd, nphi):
        return torch.reshape(torch.mm(sys_mat.T, torch.reshape(sino, (nrd*nphi,1))), (nxd,nxd))



    sys_mat  = make_torch_system_matrix(nxd, nrd, nphi).to(device)
    sys_mat = sys_mat.to(torch.float64)


    true_sinogram_torch  = fp_system_torch(true_object_torch, sys_mat, nxd, nrd, nphi)


    cv2disp("Sinogram", torch_to_np(true_sinogram_torch), disp_scale*nxd, 0, disp_scale)

    #------------------------------------------End FP---------------------------------------------------


    #######################################CNN##########################################################

    class CNN(nn.Module):
        def __init__(self, num_channels):
            super(CNN, self).__init__()
            self.CNN = nn.Sequential(
                nn.Conv2d(1, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, num_channels, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
                nn.Conv2d(num_channels, 1, 3, padding=(1, 1), padding_mode='reflect'), nn.PReLU(),
        )
        def forward(self,x):
            x = torch.squeeze(self.CNN(x.unsqueeze(0).unsqueeze(0)))
            return x

    cnn = CNN(32).to(device)

    #######################################################################################################################


    #-------------------------------------------FBP------------------------------------------------------------------------

    class FBP_CNN_Net(nn.Module):
        def __init__(self, cnn, sino_for_reconstruction):
            super(FBP_CNN_Net, self).__init__()
            self.sino_ones = torch.ones_like(sino_for_reconstruction)
            self.sens_image = bp_system_torch(self.sino_ones, sys_mat, nxd, nrd, nphi)
            self.cnn = cnn
            self.prelu = nn.PReLU()

        def forward(self, sino_for_reconstruction):
            filtered_sino = self.cnn(sino_for_reconstruction)
            recon = bp_system_torch(filtered_sino, sys_mat, nxd, nrd, nphi) / (self.sens_image+1.0e-15)
            recon = self.prelu(recon)
            fpsino = fp_system_torch(recon, sys_mat, nxd, nrd, nphi)


            cv2disp("Filtered Sino", torch_to_np(filtered_sino), disp_scale*(nxd+nrd), 0, disp_scale)
            cv2disp("FBP", torch_to_np(recon), disp_scale*(nxd+nrd), disp_scale*nphi+30, disp_scale)
            cv2disp("FP Sino", torch_to_np(fpsino), disp_scale*(nxd+nrd+nphi), 0, disp_scale)
            cv2disp("Difference", torch_to_np(fpsino - sino_for_reconstruction), disp_scale*(nxd+nrd+nrd), disp_scale)
            cv2.waitKey(1)
            return recon, fpsino

    fbpnet = FBP_CNN_Net(cnn, true_sinogram_torch).to(device)

    #===========================================TRAINING OF THE NETWORK

    loss_fun = nn.MSELoss()
    optimiser = torch.optim.Adam(fbpnet.parameters(), lr = 1e-4)

    train_loss = list()
    epochs = 60000

    for ep in range(epochs):
        fbp_recon, rec_fp = fbpnet(true_sinogram_torch)
        loss = loss_fun(rec_fp, torch.squeeze(true_sinogram_torch))
        train_loss.append(loss.item())
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()
        print('Epoch %d Training loss = %f' % ( ep, train_loss[-1]))

But the error tells me `RuntimeError: Input type (double) and bias type (float) should be the same`. Does anyone know what this means and how to fix it? I tried running the code in chunks, and the failure happens when I finally run the for loop that trains the network. I'm not really sure what the input is, what the bias is, or how to change them (or what to change them to).
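
For reference, a stripped-down example that produces the same error (a sketch, not my actual code):

    import numpy as np
    import torch, torch.nn as nn

    conv = nn.Conv2d(1, 1, 3)                    # weight and bias are float32 by default
    x = torch.from_numpy(np.ones((1, 1, 8, 8)))  # numpy defaults to float64, i.e. double
    conv(x)  # RuntimeError: Input type (double) and bias type (float) should be the same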

python image-processing pytorch pycharm conv-neural-network
1 Answer

Your numpy arrays are of type `float64`, while your pytorch tensors are of type `float32`. When you call `torch.from_numpy`, the numpy data type is preserved.
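
You can check this directly:

    import numpy as np
    import torch

    t = torch.from_numpy(np.zeros((2, 2)))
    print(t.dtype)                  # torch.float64 -- the numpy dtype carries over
    print(torch.zeros(2, 2).dtype)  # torch.float32 -- pytorch's default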

Find all the places in your code where a numpy array is converted to a torch tensor, and call `.float()` on the output.
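
Applied to your code, that would mean changing the `np_to_00torch` helper along these lines (a sketch, untested):

    def np_to_00torch(np_array):
        # .float() casts the float64 numpy data down to float32, matching
        # the default dtype of the network's weights and biases
        return torch.from_numpy(np_array).float().unsqueeze(0).unsqueeze(0)

Note that your code also explicitly casts the system matrix to double (`sys_mat = sys_mat.to(torch.float64)`); changing that to `torch.float32`, or simply dropping the cast (since `torch.zeros` already returns `float32`), keeps the forward and back projections in `float32` as well.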
