How can I make `__getitems__` return a dictionary?


In a torch `Dataset`, in addition to the mandatory `__getitem__` method, you can implement a `__getitems__` method.

In my case, `__getitem__` returns a dictionary, but I can't figure out how to do the same with `__getitems__`.

import numpy as np
import torch
from torch.utils.data import DataLoader

class StackOverflowDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        return {'item': self._data[idx], 'whatever': idx*self._data[idx]+3}

    def __getitems__(self, idxs):
        return {'item': self._data[idxs], 'whatever': idxs*self._data[idxs]+3}
    
    def __len__(self):
        return len(self._data)

dataset = StackOverflowDataset(np.random.random(5))
for X in DataLoader(dataset, 2):
    print(X)
    break

If I comment out `__getitems__`, it works, but leaving it in raises `KeyError: 0`:

KeyError                                  Traceback (most recent call last)
Cell In[182], line 15
     12         return len(self._data)
     14 dataset = StackOverflowDataset(np.random.random(5))
---> 15 for X in DataLoader(dataset, 2):
     16     print(X)
     17     break

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
    671 def _next_data(self):
    672     index = self._next_index()  # may raise StopIteration
--> 673     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    674     if self._pin_memory:
    675         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:55, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     53 else:
     54     data = self.dataset[possibly_batched_index]
---> 55 return self.collate_fn(data)

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py:317, in default_collate(batch)
    256 def default_collate(batch):
    257     r"""
    258     Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
    259 
   (...)
    315         >>> default_collate(batch)  # Handle `CustomType` automatically
    316     """
--> 317     return collate(batch, collate_fn_map=default_collate_fn_map)

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py:137, in collate(batch, collate_fn_map)
    109 def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    110     r"""
    111     General collate function that handles collection type of element within each batch.
    112 
   (...)
    135         for the dictionary of collate functions as `collate_fn_map`.
    136     """
--> 137     elem = batch[0]
    138     elem_type = type(elem)
    140     if collate_fn_map is not None:

KeyError: 0
python pytorch pytorch-dataloader
1 Answer

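This happens because PyTorch's default collate function accesses the batch by position, starting with `batch[0]` (the last frame of the traceback). A dictionary whose keys are `'item'` and `'whatever'` has no key `0`, hence `KeyError: 0`. The official documentation says: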

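"Subclasses could also optionally implement `__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples."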
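The failure can be reproduced without PyTorch. The sketch below mimics only the `elem = batch[0]` line shown in the traceback; `first_element`, `as_dict`, and `as_list` are hypothetical names made up for illustration, and the function is not PyTorch's actual `collate`.

```python
def first_element(batch):
    # Mimics only the `elem = batch[0]` step from the traceback.
    return batch[0]


# Shape returned by the question's __getitems__: one dict for the whole batch.
as_dict = {"item": [0.1, 0.2], "whatever": [3.0, 3.2]}

# Shape the documentation describes: a list with one sample per index.
as_list = [{"item": 0.1, "whatever": 3.0}, {"item": 0.2, "whatever": 3.2}]

print(first_element(as_list))  # positional access works on a list

try:
    first_element(as_dict)  # a dict treats 0 as a key, which does not exist
except KeyError as err:
    print("KeyError:", err)
```

With a list, index `0` selects the first sample; with a dictionary, `0` is looked up as a key, which is what produces the error.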

In other words, `__getitems__` should return a list with one sample per index, not a dictionary.
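One possible way to apply this to the dataset from the question is sketched below. It assumes the data is a one-dimensional NumPy array, as in the question; the class name `ListOfDictsDataset` and the use of `np.arange` instead of random data are choices made here for illustration. `__getitems__` still fetches the whole batch in one vectorised step, then splits the result into one dictionary per sample so that the default collate function can rebuild a dictionary of batched tensors.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ListOfDictsDataset(Dataset):  # hypothetical name
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        return {"item": self._data[idx], "whatever": idx * self._data[idx] + 3}

    def __getitems__(self, idxs):
        # Fetch the whole batch at once with NumPy fancy indexing.
        idxs = np.asarray(idxs)
        items = self._data[idxs]
        whatever = idxs * items + 3
        # Return a list with one dict per sample, not one dict per batch.
        return [{"item": i, "whatever": w} for i, w in zip(items, whatever)]

    def __len__(self):
        return len(self._data)


# Deterministic data so the result is easy to check by hand.
dataset = ListOfDictsDataset(np.arange(5, dtype=np.float64))
batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch)  # a dict with keys 'item' and 'whatever', each a tensor of length 2
```

The loop from the question (`for X in DataLoader(dataset, 2)`) should then yield dictionaries of tensors, the same shape of output as when only `__getitem__` is defined.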
