我正在使用 PyTorch 的 DataLoader 来加载我的数据集。我注意到,当我设置 num_workers > 0 时,我的程序在训练期间无限期挂起。但是,当 num_workers = 0 时,它工作正常。
这是我的代码的简化版本:
class MedianFilter:
def __init__(self, kernel_size=3):
self.kernel_size = kernel_size
def __call__(self, img):
return img
train_transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
MedianFilter(),
transforms.RandomAffine(degrees=40, translate=(0.125, 0.125)),
transforms.RandomResizedCrop(size=(28, 28), scale=(1, 1), ratio=(1, 1), interpolation=InterpolationMode.BILINEAR),
transforms.ToTensor()
])
val_transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
MedianFilter(),
transforms.ToTensor()
])
train_dataset = ImageFolder(root='../Dataset/Original/train/', transform=train_transform)
val_dataset = ImageFolder(root='../Dataset/Original/val/', transform=val_transform)
train_dataloader = DataLoader(train_dataset, batch_size=64, pin_memory=True, num_workers=3, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, pin_memory=True, num_workers=3)
dataloader = {'train':train_dataloader, 'val':val_dataloader}
def train_model(model, dataloader, criterion, optimizer, scheduler, num_epochs):
acc_history = {'train' : [], 'val' : []}
loss_history = {'train' : [], 'val' : []}
best_acc = 0.0
for epoch in range(1, num_epochs+1):
print(f"Epoch{epoch}:")
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
running_correct = 0
running_loss = 0.0
totalIm = 0
for data, _label in dataloader[phase]: # Stuck here
在此代码中,
MedianFilter
类是一个简单的恒等函数。尽管如此,当num_workers > 0
时,程序仍然挂起。
为什么会发生这种情况以及如何解决这个问题?
我尝试将
MedianFilter
类简化为一个简单的恒等函数,仅返回输入图像。尽管进行了这种简化,程序在 num_workers > 0
时仍然挂起。我预计此更改将解决该问题,因为 MedianFilter 类不再进行任何重要的计算。然而,问题仍然存在。
我还尝试在没有自定义 torchvision 转换和设置 num_workers > 0
的情况下运行代码。在这种情况下,代码将按预期运行。
听起来没有工作人员来响应数据加载器的查询。您可以通过破坏程序并查看程序被破坏时的位置来进行验证。如果您在工作循环中发现它,请检查您如何初始化工作人员。