RuntimeError:无法推断图像的数据类型

问题描述 投票:0回答:0

我正在尝试编写用于对象检测的代码。我使用了自己的数据集,一切似乎都很好。但是在训练中,我得到了提到的错误。我将 collate_fn 定义为:

def collate_fn(batch):
    return tuple(zip(*batch))

然后 data_loader 为:

data_loader = torch.utils.data.DataLoader( dataset, batch_size=2, shuffle=True, 
                                         collate_fn=collate_fn)

最后,在训练阶段,我得到这个错误:

num_epochs = 5
losses = []

for epoch in range(num_epochs):
   for i, (inputs, targets) in enumerate(data_loader):
       inputs = torch.tensor(inputs).to(device)
       targets = torch.tensor(targets).to(device)

错误是:

RuntimeError                              Traceback (most recent call last)
<ipython-input-76-f5d22c9fc6ce> in <module>
      4 for epoch in range(num_epochs):
      5     for i, (inputs, targets) in enumerate(data_loader):
----> 6         inputs = torch.tensor(inputs).to(device)
      7         targets = torch.tensor(targets).to(device)
      8 

RuntimeError: Could not infer dtype of Image

一开始我的错误是:

'tuple' object has no attribute 'to'

当我的代码受阻时:

inputs= inputs.to(device)

但是将代码更改为:

inputs = torch.tensor(inputs).to(device)

也许解决了那个错误。但是现在,我得到了提到的错误

RuntimeError: Could not infer dtype of Image

这里有什么问题? 我该怎么办?

python deep-learning pytorch neural-network
© www.soinside.com 2019 - 2024. All rights reserved.