我在使用 pytorch 自定义数据集时遇到错误。这个问题对我来说真的很奇怪,因为它正在工作,我没有更改代码上的任何内容。这是场景:
构建深度学习模型后,我通过 10 到 100 个时期的训练来测试它。它工作正常,但我发现该模型需要进行更多轮数的训练才能获得更好的结果。
所以,我将纪元数更改为 500。我的 GPU 崩溃了,可能是因为我每 10 纪元打印一次结果,并且内存不足(我不知道真正的问题是什么)
现在重新启动 GPU 后,jupiter 笔记本向我抛出服务器错误,状态代码为 500
我在互联网上搜索并通过运行以下命令找到了我的 jupiter 笔记本的解决方案:
pip install --upgrade nbconvert
之后,代码就不再起作用了。我尝试调试它,但发现了一些奇怪的东西:
提前感谢您的回答。
这是自定义数据集类:
将其放入 src/custom_dataset.py 例如
import os
from natsort import natsorted
from PIL import Image
# Let's see if we have an available GPU
from datasets import Dataset
class LoadPairedDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = os.listdir(root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
print(idx)
img_name = os.path.join(self.root_dir, self.images[idx])
image = Image.open(img_name)
if self.transform:
image = self.transform(image)
return image
这是我用来调用自定义数据集的代码:
将其放入木星笔记本单元中
# Imports
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from src.custom_dataset import LoadPairedDataset
# Define your own class LoadFromFolder
class CustomImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = os.listdir(root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
print(idx)
img_name = os.path.join(self.root_dir, self.images[idx])
image = Image.open(img_name)
if self.transform:
image = self.transform(image)
return image
base_path = "../lol-custom"
# dataloader = {"train_n": None, "train_p": None}
transform = transforms.Compose([
transforms.ToTensor()
])
train_data = CustomImageDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
train_data,
batch_size=5,
sampler=None,
num_workers=0
)
# The output will be:
# 0 1 2 3 4 from the print(idx) in the __getitem__ function in CustomImageDataset class
# torch.Size([5, 3, 400, 600]) from the below print
print(next(iter(dataloader)).shape) # This will print 0 1 2 3 4
print("######### The below throw an error ##############")
train_data = LoadPairedDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
train_data,
batch_size=5,
sampler=None,
num_workers=0
)
# The output will be:
# [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
# THEN ERROR: TypeError: list indices must be integers or slices, not list
print(next(iter(dataloader)).shape)
我正在使用 torch 2.01 和 python 3.9.18
最后,这是错误的输出和堆栈跟踪
0
1
2
3
4
torch.Size([5, 3, 400, 600])
######### The below throw an error ##############
[0, 1, 2, 3, 4]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 61
52 dataloader = torch.utils.data.DataLoader(
53 train_data,
54 batch_size=5,
55 sampler=None,
56 num_workers=0
57 )
58 # This output will be:
59 # [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
60 # THEN ERROR: TypeError: list indices must be integers or slices, not list
---> 61 print(next(iter(dataloader)).shape)
File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
675 def _next_data(self):
676 index = self._next_index() # may raise StopIteration
--> 677 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
678 if self._pin_memory:
679 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
47 if self.auto_collation:
48 if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
51 data = [self.dataset[idx] for idx in possibly_batched_index]
File ~\anaconda3\envs\mmie\lib\site-packages\datasets\arrow_dataset.py:2807, in Dataset.__getitems__(self, keys)
2805 def __getitems__(self, keys: List) -> List:
2806 """Can be used to get a batch using a list of integers indices."""
-> 2807 batch = self.__getitem__(keys)
2808 n_examples = len(batch[next(iter(batch))])
2809 return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
File ~\Projects\mmie\src\custom_dataset.py:21, in LoadPairedDataset.__getitem__(self, idx)
19 def __getitem__(self, idx):
20 print(idx)
---> 21 img_name = os.path.join(self.root_dir, self.images[idx])
22 image = Image.open(img_name)
24 if self.transform:
TypeError: list indices must be integers or slices, not list
这可能是因为您的自定义数据集继承自
Dataset
类,但Datset
的含义发生了变化。
在单独的文件中时,您将
Dataset
定义为 from datasets import Dataset
但在 jupyter 单元中,Dataset
是 from torch.utils.data import Dataset
,这显然是不同的。我建议您也保留在单独文件中有效的定义