Here is my code. First, I define the U-Net model as an nn.Module subclass, as shown below:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        # encoder
        self.conv1 = nn.Conv3d(1, 32, 3, padding=1)
        self.conv1_1 = nn.Conv3d(32, 32, 3, padding=1)
        self.conv2 = nn.Conv3d(32, 64, 3, padding=1)
        self.conv2_2 = nn.Conv3d(64, 64, 3, padding=1)
        # bottleneck
        self.conv3 = nn.Conv3d(64, 128, 3, padding=1)
        self.conv3_3 = nn.Conv3d(128, 128, 3, padding=1)
        # decoder (transposed convolutions for upsampling)
        self.convT1 = nn.ConvTranspose3d(128, 64, 3, stride=(2, 2, 2), padding=1, output_padding=1)
        self.conv4 = nn.Conv3d(128, 64, 3, padding=1)
        self.conv4_4 = nn.Conv3d(64, 64, 3, padding=1)
        self.convT2 = nn.ConvTranspose3d(64, 32, 3, stride=(2, 2, 2), padding=1, output_padding=1)
        self.conv5 = nn.Conv3d(64, 32, 3, padding=1)
        self.conv5_5 = nn.Conv3d(32, 32, 3, padding=1)
        self.conv6 = nn.Conv3d(32, 1, 3, padding=1)

    def forward(self, inputs):
        # encoder path
        conv1 = F.relu(self.conv1(inputs))
        conv1 = F.relu(self.conv1_1(conv1))
        pool1 = F.max_pool3d(conv1, 2)

        conv2 = F.relu(self.conv2(pool1))
        conv2 = F.relu(self.conv2_2(conv2))
        pool2 = F.max_pool3d(conv2, 2)

        # bottleneck
        conv3 = F.relu(self.conv3(pool2))
        conv3 = F.relu(self.conv3_3(conv3))

        # decoder path with skip connections
        conv3 = self.convT1(conv3)
        up1 = torch.cat((conv3, conv2), dim=1)
        conv4 = F.relu(self.conv4(up1))
        conv4 = F.relu(self.conv4_4(conv4))

        conv4 = self.convT2(conv4)
        up2 = torch.cat((conv4, conv1), dim=1)
        conv5 = F.relu(self.conv5(up2))
        conv5 = F.relu(self.conv5_5(conv5))

        conv6 = F.relu(self.conv6(conv5))
        return conv6
Then I run my U-Net with the code below. Note that when creating the model I move it to CUDA, and I also move the input data and its labels to CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = unet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
datasets = torch.utils.data.TensorDataset(data_recon, data_truth)
train_loader = DataLoader(datasets, batch_size=2, shuffle=True)
import datetime  # used for the timestamp in the epoch log below


def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        loss_train = 0
        for imgs, labels in train_loader:
            imgs.to(device)
            labels.to(device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_train += loss.item()

        print('{} Epoch {}, Training loss {}'.format(datetime.datetime.now(), epoch, float(loss_train)))


training_loop(50, optimizer, model, loss_fn, train_loader)
But I get this error:
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-573c18dee5b1> in <module>
----> 1 training_loop(50, optimizer, model, loss_fn, train_loader)
<ipython-input-30-81cea9bcd2ec> in training_loop(n_epochs, optimizer, model, loss_fn, train_loader)
5 imgs.to(device)
6 labels.to(device)
----> 7 outputs = model(imgs)
8 loss = loss_fn(outputs, labels)
9
/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
<ipython-input-15-5dac9d28f19c> in forward(self, inputs)
18
19 def forward(self, inputs):
---> 20 conv1 = F.relu(self.conv1(inputs))
21 conv1 = F.relu(self.conv1_1(conv1))
22 pool1 = F.max_pool3d(conv1, 2)
/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
474 self.dilation, self.groups)
475 return F.conv3d(input, self.weight, self.bias, self.stride,
--> 476 self.padding, self.dilation, self.groups)
477
478
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'weight'
How can I fix this?
The problem is in these lines:

imgs.to(device)
labels.to(device)

.to(device) returns a new tensor and does not modify imgs and labels in place, so they are still CPU tensors when they reach the model, which is exactly what the CUDA/CPU mismatch error reports. You can fix it simply by assigning the returned tensors back, as shown below:
imgs = imgs.to(device)
labels = labels.to(device)
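For reference, here is a minimal sketch of the corrected inner loop, reusing the variable names from the question; the rest of training_loop stays unchanged. Note that nn.Module.to() moves a module's parameters in place (and returns the module), so the earlier model = unet().to(device) line is fine as written, whereas Tensor.to() always returns a new tensor that must be bound back to a name.

for imgs, labels in train_loader:
    # Tensor.to() is out-of-place: reassign so the names point at the GPU copies
    imgs = imgs.to(device)
    labels = labels.to(device)

    outputs = model(imgs)            # inputs and weights are now on the same device
    loss = loss_fn(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_train += loss.item()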