DataLoader freezes

Question

My PyTorch (1.11.0) DataLoader over a custom dataset occasionally freezes.

I cannot reproduce the freeze; it seems random: it usually runs without problems, but sometimes it gets stuck. When I interrupt it (Ctrl+C), I get this traceback:

   idx, data = self._get_data()
  File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1163, in _get_data
    success, data = self._try_get_data()
  File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1011, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/opt/conda/envs/torch/lib/python3.8/queue.py", line 179, in get
    self.not_empty.wait(remaining)
  File "/opt/conda/envs/torch/lib/python3.8/threading.py", line 306, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt

I'm asking here because this problem has been raised several times on the official forum, without any replies.

I tried iterating over the dataset to catch errors, but hit no problems: when I loop over the dataset directly, it never freezes. I'm developing on an Ubuntu 20.04 Linux pod on Kubernetes.

I know concurrency in Python is fairly messy, but can anyone suggest what to check?

The custom dataset:

import numpy as np
import pandas as pd
import torch
from os.path import join
from PIL import Image
from torch.utils.data import Dataset


class MultiModalDataset(Dataset):
    def __init__(self, img_dataset: pd.DataFrame, text_dataset: pd.DataFrame, 
            img_fld: str, img_transforms=None, n_classes=None, img_size=224,
            n_sentences=1, n_tokens=12, collate_fn=None, l1normalization=False, verbose=False):
        super().__init__()
        self.n_classes = n_classes or img_dataset.shape[1]
        assert self.n_classes == img_dataset.shape[1]
        self.img_ds = img_dataset
        # print(text_dataset.head())
        self.text_ds = text_dataset.set_index("image_filename")
        self.img_fld = img_fld
        self.transforms = img_transforms
        self.img_size = img_size
        self.n_sentences = n_sentences
        self.n_tokens = n_tokens
        self.collate_fn = collate_fn
        self.l1normalization = l1normalization
        self.verbose = verbose

    def __len__(self):
        return len(self.img_ds)
    
    def __getitem__(self, idx):
        assert (idx >=0) and (idx < len(self.img_ds))
        item = self.img_ds.iloc[idx]
        filename = item.name
        labels = item.values
        if self.l1normalization:
            nlabs = sum(labels)
            assert nlabs > 0, f"dataset, at index {idx}, no labels found"
            labels = labels / nlabs
        
        text = self.text_ds.loc[filename, "enc_text"]

        if self.collate_fn is not None:
            padded_text = self.collate_fn(text, n_sents=self.n_sentences, max_tokens=self.n_tokens, verbose=self.verbose)
        else:
            padded_text = text

        return self.load_image(filename), torch.tensor(labels.astype(np.float32)), torch.tensor(padded_text)

    def load_image(self, img_filename):
        fn = join(self.img_fld, img_filename)
        img = Image.open(fn)
        if self.transforms is not None:
            img = self.transforms(img)
        return img
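One detail worth double-checking in load_image: PIL's Image.open is lazy and keeps the underlying file handle open until the pixel data is actually read, so handles can accumulate across worker processes. A minimal sketch of an eager variant (load_image_safe is an illustrative name, not part of the original code):

```python
from PIL import Image

def load_image_safe(path):
    # Image.open is lazy: the file stays open until the pixels are read.
    # convert() forces a full decode, and the context manager then closes
    # the file handle promptly instead of waiting for garbage collection.
    with Image.open(path) as img:
        return img.convert("RGB")
```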

The DataLoader:

# drop_last expects a bool; a list like [False, False, False] is truthy and would silently act as True
DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, drop_last=False, pin_memory=False)
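One way to make a hang observable rather than silent: DataLoader accepts a timeout argument (in seconds); if the workers fail to deliver a batch within that window, iteration raises RuntimeError instead of blocking forever. A sketch with a stand-in dataset (TinyDataset is illustrative, not part of the original code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TinyDataset(Dataset):
    """Stand-in for the real dataset: deterministic scalar tensors."""
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.tensor(float(idx))

# timeout=60: if no batch arrives from the workers within 60 s,
# the loader raises RuntimeError instead of hanging silently.
loader = DataLoader(TinyDataset(), batch_size=4, num_workers=2, timeout=60)
n_batches = sum(1 for _ in loader)
print(n_batches)  # 16 items / batch_size 4 -> 4 batches
```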

I iterate over the dataset/DataLoader with:

for bi, (_, _, _) in enumerate(dataloader):
 ...
  • Toggling pin_memory to True does not solve the problem.
  • Some people suggested setting num_workers to zero, but I can't: it becomes far too slow. Changing the number of workers to any other value > 0 makes no difference (it still freezes).
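When a run does freeze, a stack dump taken from inside the stuck process is far more informative than the truncated Ctrl+C traceback. The standard library's faulthandler can do this without killing the process; a sketch (SIGUSR1 is Unix-only):

```python
import faulthandler
import signal

# On Unix: after this, `kill -USR1 <pid>` prints every thread's Python
# stack to stderr while leaving the process running.
if hasattr(signal, "SIGUSR1"):
    faulthandler.register(signal.SIGUSR1)

# A one-off dump can also be written to any real file (needs a fileno):
with open("stacks.txt", "w") as f:
    faulthandler.dump_traceback(file=f)
```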
python pytorch
1 Answer

I don't fully understand why, but this solved the problem for me:

if __name__ == '__main__':
    import torch
    # Switch the worker start method from the default 'fork' (on Linux)
    # to 'spawn'. Must run before any DataLoader/CUDA work, hence the guard.
    torch.multiprocessing.set_start_method('spawn')
    main()
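If calling set_start_method globally clashes with other libraries (it raises RuntimeError when the context has already been set), the start method can also be chosen per loader via multiprocessing_context. A sketch with a stand-in dataset (TinyDataset and main are illustrative names, not from the original code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TinyDataset(Dataset):
    """Stand-in dataset returning deterministic scalar tensors."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(float(idx))

def main():
    # 'spawn' applies only to this loader; the global start method is untouched.
    loader = DataLoader(TinyDataset(), batch_size=4, num_workers=2,
                        multiprocessing_context="spawn")
    return sum(1 for _ in loader)

if __name__ == "__main__":
    # Spawned workers re-import this module, hence the guard.
    print(main())  # 8 items / batch_size 4 -> 2 batches
```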