我正在尝试编写用于对象检测的代码。我使用了自己的数据集,一切似乎都很好。但是在训练中,我得到了提到的错误。我将 collate_fn 定义为:
def collate_fn(batch):
return tuple(zip(*batch))
然后 data_loader 为:
data_loader = torch.utils.data.DataLoader( dataset, batch_size=2, shuffle=True,
collate_fn=collate_fn)
最后,在训练阶段,我得到这个错误:
num_epochs = 5
losses = []
for epoch in range(num_epochs):
for i, (inputs, targets) in enumerate(data_loader):
inputs = torch.tensor(inputs).to(device)
targets = torch.tensor(targets).to(device)
错误是:
RuntimeError Traceback (most recent call last)
<ipython-input-76-f5d22c9fc6ce> in <module>
4 for epoch in range(num_epochs):
5 for i, (inputs, targets) in enumerate(data_loader):
----> 6 inputs = torch.tensor(inputs).to(device)
7 targets = torch.tensor(targets).to(device)
8
RuntimeError: Could not infer dtype of Image
一开始我的错误是:
'tuple' object has no attribute 'to'
当我的代码受阻时:
inputs= inputs.to(device)
但是将代码更改为:
inputs = torch.tensor(inputs).to(device)
也许解决了那个错误。但是现在,我得到了提到的错误
RuntimeError: Could not infer dtype of Image
这里有什么问题? 我该怎么办?