I'm creating a data loader from a dataset.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=True)
Each index of the dataset holds two elements: a matrix and an array.
a, b = train_ds.__getitem__(0)
print(type(a))
print(a)
print(type(b))
print(b)
This returns:
<class 'numpy.ndarray'>
[[ 44.9329 46.08967 44.9329 ... 99.2188 99.735664 99.17029 ]
[ 44.9329 44.9329 44.9329 ... 114.164474 114.292244 114.11395 ]
[ 44.9329 45.03071 44.9329 ... 114.57378 114.56599 114.49552 ]
...
[ 44.9329 44.9329 44.9329 ... 52.242996 50.12293 44.9329 ]
[ 44.9329 44.9329 44.9329 ... 44.9329 44.9329 44.9329 ]
[ 44.9329 44.9329 44.9329 ... 44.9329 44.9329 44.9329 ]]
<class 'numpy.ndarray'>
[-0.002963486 -0.003033393 0.00371422 2.02e-06 0.004402838 -0.002704915
0.003289625 -0.002551801 -0.003632823 -0.003408553 -0.002707387
0.00278949 0.000828761 0.000849513 0.003992096 -0.002692624 0.001183484
9.43e-05 0.003836168 2.24e-05 0.003944455 -0.001950883 -0.000877485
0.001734729 -0.003225849 -0.000537016 6.53e-05 -0.003643878 -0.002444321
0.002499692 0.001538219 0.002263657 0.003073046 0.004134932 -0.002500862
-0.001662471 0.002273667 0.00375025 0.001866289 -0.002027481 0.002197658
-0.002243473 0.000943156 -0.000643054 -0.003169563 -0.003424202
0.00118924 -0.003570424 0.002273526]
But when I try to iterate over my data loader with:
for i, data in enumerate(train_dl):
I get this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
in
1 num_epochs=1 # Just for demo, adjust this higher.
----> 2 training(myModel, train_dl, num_epochs)
in training(model, train_dl, num_epochs)
18
19 # Repeat for each batch in the training set
---> 20 for i, data in enumerate(train_dl):
21 # Get the input features and target labels, and put them on the GPU
22 inputs, labels = data[0].to(device), data[1].to(device)
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\dataloader.py in __next__(self)
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\dataloader.py in _next_data(self)
675 def _next_data(self):
676 index = self._next_index() # may raise StopIteration
--> 677 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
678 if self._pin_memory:
679 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
52 else:
53 data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
263 >>> default_collate(batch) # Handle `CustomType` automatically
264 """
--> 265 return collate(batch, collate_fn_map=default_collate_fn_map)
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate(batch, collate_fn_map)
140
141 if isinstance(elem, tuple):
--> 142 return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility.
143 else:
144 try:
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in (.0)
140
141 if isinstance(elem, tuple):
--> 142 return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility.
143 else:
144 try:
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate(batch, collate_fn_map)
117 if collate_fn_map is not None:
118 if elem_type in collate_fn_map:
--> 119 return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
120
121 for collate_type in collate_fn_map:
~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate_numpy_array_fn(batch, collate_fn_map)
167 # array of string classes and object
168 if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
--> 169 raise TypeError(default_collate_err_msg_format.format(elem.dtype))
170
171 return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
Why does it throw this error and report that it only found an object?
The error you're seeing happens because PyTorch's default collate function (default_collate) does not know how to handle arbitrary objects coming out of your dataset. The message "TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object" means that default_collate expected one of the listed types but encountered something else, an "object". Judging from the last frame of the traceback (collate_numpy_array_fn checks elem.dtype), at least one of the numpy arrays your dataset returns has dtype object, which default_collate refuses to batch.
In your case, the dataset returns numpy arrays and the DataLoader tries to collate (combine) them into a batch. The DataLoader handles this most reliably when the dataset already returns PyTorch tensors with a numeric dtype, so you should convert the numpy arrays to tensors before they reach the DataLoader.
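Before touching the dataset, a quick check can confirm which of the two elements is the problem. This is just a diagnostic sketch using the train_ds variable from your snippet; default_collate only rejects numpy arrays with an object (or string) dtype:

import numpy as np

a, b = train_ds[0]
print(a.dtype, b.dtype)  # an object dtype here is what triggers the collate error

# A common cause is ragged data, e.g. rows of different lengths,
# which numpy can only store as an object array:
ragged = np.array([[1.0, 2.0], [3.0]], dtype=object)
print(ragged.dtype)  # object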
Here's how to fix it: convert the numpy arrays to PyTorch tensors inside the dataset's __getitem__ method, so that __getitem__ returns tensors. Here is a modified version of the dataset's __getitem__:
import torch

class YourDataset(torch.utils.data.Dataset):
    def __init__(self, ...):  # your other initialization arguments
        ...

    def __getitem__(self, index):
        a, b = ...  # however you're currently getting your numpy arrays
        # If either array has dtype=object, cast it to a numeric dtype first,
        # e.g. a = a.astype(np.float32); otherwise torch.from_numpy will also fail.
        a_tensor = torch.from_numpy(a)
        b_tensor = torch.from_numpy(b)
        return a_tensor, b_tensor
By doing this, every item you access from the dataset is already a tensor, and the DataLoader can batch it without any problems.
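With the dataset returning tensors, your original loop should then work. A minimal usage sketch, reusing the train_ds, batch size, and device from your code (the printed shapes are only for inspection):

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=True)

for i, data in enumerate(train_dl):
    # Each batch is a pair of tensors: the matrices stacked along a new
    # batch dimension, and the arrays stacked the same way.
    inputs, labels = data[0].to(device), data[1].to(device)
    print(inputs.shape, labels.shape)
    break  # just checking the first batch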