Error: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

Question (votes: 0, answers: 1)

I am creating a DataLoader from a dataset:

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=True)

Each index of the dataset returns two elements: a matrix and an array.

a, b = train_ds.__getitem__(0)
print(type(a))
print(a)
print(type(b))
print(b)

This returns:

<class 'numpy.ndarray'>

[[ 44.9329    46.08967   44.9329   ...  99.2188    99.735664  99.17029 ]
 [ 44.9329    44.9329    44.9329   ... 114.164474 114.292244 114.11395 ]
 [ 44.9329    45.03071   44.9329   ... 114.57378  114.56599  114.49552 ]
 ...
 [ 44.9329    44.9329    44.9329   ...  52.242996  50.12293   44.9329  ]
 [ 44.9329    44.9329    44.9329   ...  44.9329    44.9329    44.9329  ]
 [ 44.9329    44.9329    44.9329   ...  44.9329    44.9329    44.9329  ]]

<class 'numpy.ndarray'>

[-0.002963486 -0.003033393 0.00371422 2.02e-06 0.004402838 -0.002704915
 0.003289625 -0.002551801 -0.003632823 -0.003408553 -0.002707387
 0.00278949 0.000828761 0.000849513 0.003992096 -0.002692624 0.001183484
 9.43e-05 0.003836168 2.24e-05 0.003944455 -0.001950883 -0.000877485
 0.001734729 -0.003225849 -0.000537016 6.53e-05 -0.003643878 -0.002444321
 0.002499692 0.001538219 0.002263657 0.003073046 0.004134932 -0.002500862
 -0.001662471 0.002273667 0.00375025 0.001866289 -0.002027481 0.002197658
 -0.002243473 0.000943156 -0.000643054 -0.003169563 -0.003424202
 0.00118924 -0.003570424 0.002273526]

But when I try to iterate over the DataLoader with:

for i, data in enumerate(train_dl):

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
 in 
      1 num_epochs=1   # Just for demo, adjust this higher.
----> 2 training(myModel, train_dl, num_epochs)

 in training(model, train_dl, num_epochs)
     18 
     19     # Repeat for each batch in the training set
---> 20     for i, data in enumerate(train_dl):
     21         # Get the input features and target labels, and put them on the GPU
     22         inputs, labels = data[0].to(device), data[1].to(device)

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\dataloader.py in __next__(self)
    631                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632                 self._reset()  # type: ignore[call-arg]
--> 633             data = self._next_data()
    634             self._num_yielded += 1
    635             if self._dataset_kind == _DatasetKind.Iterable and \

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    675     def _next_data(self):
    676         index = self._next_index()  # may raise StopIteration
--> 677         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    678         if self._pin_memory:
    679             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     52         else:
     53             data = self.dataset[possibly_batched_index]
---> 54         return self.collate_fn(data)

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
    263             >>> default_collate(batch)  # Handle `CustomType` automatically
    264     """
--> 265     return collate(batch, collate_fn_map=default_collate_fn_map)

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate(batch, collate_fn_map)
    140 
    141         if isinstance(elem, tuple):
--> 142             return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
    143         else:
    144             try:

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in (.0)
    140 
    141         if isinstance(elem, tuple):
--> 142             return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
    143         else:
    144             try:

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate(batch, collate_fn_map)
    117     if collate_fn_map is not None:
    118         if elem_type in collate_fn_map:
--> 119             return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
    120 
    121         for collate_type in collate_fn_map:

~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate_numpy_array_fn(batch, collate_fn_map)
    167     # array of string classes and object
    168     if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
--> 169         raise TypeError(default_collate_err_msg_format.format(elem.dtype))
    170 
    171     return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

Why does it raise this error, and why does it say it found an "object"?

python pytorch torch pytorch-dataloader
1 Answer
0 votes

The error occurs because PyTorch's default collate function (default_collate) does not know how to batch one of the values your dataset returns.

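The message "TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object" means that default_collate expected one of the listed types but found something else. Here "object" is the NumPy dtype of one of your arrays: the traceback shows the error is raised in collate_numpy_array_fn, which checks elem.dtype and rejects arrays whose dtype is object or string.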
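A minimal reproduction, assuming a PyTorch release that exports default_collate from torch.utils.data (1.11 or later). It builds two hypothetical two-sample batches shaped like the question's (matrix, vector) pairs: one with float32 arrays, and one whose vectors are cast to dtype=object. Only the second one fails.

```python
import numpy as np
from torch.utils.data import default_collate

# Hypothetical samples shaped like the question's: (matrix, vector) pairs.
float_batch = [(np.zeros((3, 4), dtype=np.float32), np.zeros(5, dtype=np.float32)) for _ in range(2)]
# Same values, but the vectors are stored with dtype=object.
object_batch = [(m, v.astype(object)) for m, v in float_batch]

# Numeric NumPy arrays collate fine: each field becomes one stacked tensor.
xs, ys = default_collate(float_batch)
print(xs.shape, ys.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 5])

# The object-dtype version raises the TypeError from the question.
message = ""
try:
    default_collate(object_batch)
except TypeError as err:
    message = str(err)
    print(message)
```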

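In your case the dataset returns NumPy arrays, which the DataLoader tries to collate (combine) into a batch. default_collate can normally handle NumPy arrays: it converts numeric arrays with torch.as_tensor and stacks them. It cannot handle an array with dtype=object, though, and one of yours has that dtype. The most likely candidate is b: its printout mixes plain and scientific notation with single-space separators, which is how NumPy prints object arrays, while a float array would be printed in one uniform format. You can confirm it with print(a.dtype, b.dtype). This often happens when, for example, the values were taken from a pandas DataFrame that also has non-numeric columns. The fix is to cast the arrays to a numeric dtype and convert them to PyTorch tensors before the DataLoader sees them.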
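To check the diagnosis, inspect the dtype. In this sketch b is a hypothetical stand-in for the question's vector, built with dtype=object on purpose; it shows that torch.from_numpy also rejects such an array, and that an explicit cast (here to np.float32, an assumed target type) fixes it.

```python
import numpy as np
import torch

# Hypothetical stand-in for the question's `b`: numbers stored with dtype=object.
b = np.array([-0.002963486, 2.02e-06, 0.00371422], dtype=object)
print(b.dtype)  # object

# torch.from_numpy only accepts numeric/bool dtypes, so converting as-is fails.
try:
    torch.from_numpy(b)
    from_numpy_failed = False
except TypeError:
    from_numpy_failed = True

# Casting to a real float dtype first makes the conversion work.
b_tensor = torch.from_numpy(b.astype(np.float32))
print(b_tensor.dtype, b_tensor.shape)  # torch.float32 torch.Size([3])
```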

Here's how to fix it:

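  1. In your dataset's __getitem__ method, cast the NumPy arrays to a numeric dtype (for example np.float32) and convert them to PyTorch tensors.
  2. Return those tensors from __getitem__.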

Here is a modified version of your dataset's __getitem__:

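import numpy as np
import torch

class YourDataset(torch.utils.data.Dataset):
    def __init__(self, features, targets):  # hypothetical arguments; use your own
        self.features = features
        self.targets = targets

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        a, b = self.features[index], self.targets[index]  # however you currently get your numpy arrays
        # Cast to float32 first: torch.from_numpy raises TypeError on dtype=object arrays
        a_tensor = torch.from_numpy(np.asarray(a, dtype=np.float32))
        b_tensor = torch.from_numpy(np.asarray(b, dtype=np.float32))
        return a_tensor, b_tensor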
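An end-to-end sketch of the same idea, using random data in place of the real dataset. ArrayPairDataset, the sample count and the array sizes are all hypothetical; the vectors are deliberately stored with dtype=object to mimic the question, and the loop unpacks each batch into inputs and labels the way the question's training function does.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ArrayPairDataset(Dataset):
    """Hypothetical dataset: one (matrix, vector) pair per index."""

    def __init__(self, matrices, vectors):
        self.matrices = matrices
        self.vectors = vectors

    def __len__(self):
        return len(self.matrices)

    def __getitem__(self, index):
        # Cast to float32 so object-dtype inputs become collatable numeric tensors.
        a = torch.from_numpy(np.asarray(self.matrices[index], dtype=np.float32))
        b = torch.from_numpy(np.asarray(self.vectors[index], dtype=np.float32))
        return a, b


# Fake data: 20 samples; the vectors are stored as dtype=object on purpose.
rng = np.random.default_rng(0)
matrices = [rng.normal(size=(16, 32)).astype(np.float32) for _ in range(20)]
vectors = [rng.normal(scale=1e-3, size=49).astype(object) for _ in range(20)]

train_dl = DataLoader(ArrayPairDataset(matrices, vectors), batch_size=8, shuffle=True)

batch_sizes = []
for inputs, labels in train_dl:
    batch_sizes.append(inputs.shape[0])
    print(inputs.shape, labels.shape)
```

With 20 samples and batch_size=8, the loop yields batches of 8, 8 and 4 samples, each a float32 tensor.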

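By doing this, every item you read from the dataset is already a float32 tensor, which the DataLoader can collate without problems. (Returning float32 NumPy arrays would also work, since default_collate converts numeric arrays itself; the essential step is the cast away from dtype=object.)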

© www.soinside.com 2019 - 2024. All rights reserved.