我正在尝试微调 PyTorch 分类模型来对植物病害图像进行分类。我已正确初始化 CUDA 设备并将模型、训练、验证和测试数据发送到设备。然而,在训练模型时,它使用 100% CPU 和 0% GPU。为什么会出现这种情况?
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience):
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
best_val_loss = np.inf
patience_counter = 0
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
for epoch in range(num_epochs):
model.train()
running_loss, running_corrects, total = 0.0, 0, 0
for inputs, labels in tqdm(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
_, preds = torch.max(outputs, 1)
running_corrects += torch.sum(preds == labels.data)
total += labels.size(0)
epoch_loss = running_loss / total
epoch_acc = running_corrects.double() / total
train_losses.append(epoch_loss)
train_accuracies.append(epoch_acc.item())
val_loss, val_acc = evaluate_model(model, val_loader, criterion)
val_losses.append(val_loss)
val_accuracies.append(val_acc)
print(f"Epoch {epoch}/{num_epochs-1}, Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), 'best_model.pth')
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
print("Early stopping")
break
lr_scheduler.step()
return train_losses, train_accuracies, val_losses, val_accuracies
这是我的笔记本:EfficientNet with Augmentation
我看到的一个问题是您没有使用多个进程来加载训练示例。 如果您的程序花费更多时间加载训练示例(由 CPU 完成)而不是实际训练它们(由 GPU 完成),这可以解释 GPU 利用率低的原因。
更具体地说,每次数据集加载示例时,它都必须解析文件并应用一堆数据增强。 仅通过阅读代码很难确定,但这两个步骤都可能非常昂贵。
如果这确实是问题所在,这里有两种可能的解决方法。首先是使用多个数据加载器进程。这确实很容易做到;只需将
num_workers
参数传递给 DataLoader
构造函数即可。 这种方法的唯一缺点是您需要大量 CPU 才能充分利用它,而云提供商可能不会为您提供太多。 其次是预加载整个数据集。 这只是相对较小的数据集的一个选项,但如果这适用于您,并且您可以缓存结果,这样您就不需要每次都进行完整的预加载,这可能是最快的方法。