在 torch 的
Dataset
中,在强制性 __getitem__
方法之上,您可以实现 __getitems__
方法。
在我的例子中,
__getitem__
返回一个字典,但我不知道如何用__getitems__
做同样的事情。
class StackOverflowDataset(torch.utils.data.Dataset):
def __init__(self, data):
self._data = data
def __getitem__(self, idx):
return {'item': self._data[idx], 'whatever': idx*self._data[idx]+3}
def __getitems__(self, idxs):
return {'item': self._data[idxs], 'whatever': idxs*self._data[idxs]+3}
def __len__(self):
return len(self._data)
dataset = StackOverflowDataset(np.random.random(5))
for X in DataLoader(dataset, 2):
print(X)
break
如果我注释掉
__getitems__
它可以工作,但是将其留在那里会引发 KeyError: 0
。
KeyError Traceback (most recent call last)
Cell In[182], line 15
12 return len(self._data)
14 dataset = StackOverflowDataset(np.random.random(5))
---> 15 for X in DataLoader(dataset, 2):
16 print(X)
17 break
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
671 def _next_data(self):
672 index = self._next_index() # may raise StopIteration
--> 673 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
674 if self._pin_memory:
675 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:55, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
53 else:
54 data = self.dataset[possibly_batched_index]
---> 55 return self.collate_fn(data)
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py:317, in default_collate(batch)
256 def default_collate(batch):
257 r"""
258 Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
259
(...)
315 >>> default_collate(batch) # Handle `CustomType` automatically
316 """
--> 317 return collate(batch, collate_fn_map=default_collate_fn_map)
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py:137, in collate(batch, collate_fn_map)
109 def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
110 r"""
111 General collate function that handles collection type of element within each batch.
112
(...)
135 for the dictionary of collate functions as `collate_fn_map`.
136 """
--> 137 elem = batch[0]
138 elem_type = type(elem)
140 if collate_fn_map is not None:
KeyError: 0
这是因为 pytorch 尝试通过索引访问数据,从 0 开始。 官方文档说:
子类还可以选择实现 getitems(),以提高速度 批量样品加载。此方法接受索引列表 批次样品并返回样品列表。
换句话说,
__getitems__
应该返回列表,而不是字典。