我正在尝试使用 Pytorch 中的 UNet 从多维(8 波段)卫星图像中提取预测掩模。我无法让预测蒙版看起来有些预期/连贯。我不确定问题是否在于我的训练数据的格式化方式、我的训练代码或我用于进行预测的代码。我怀疑这是我的训练数据输入模型的方式。我有 8 个波段卫星图像和单波段掩码,其值范围为 0-n 个类别,其中 0 是背景,1-n 是目标标签,如下所示:
在单通道示例的情况下,图像形状为 (8, 512, 512),掩模形状为 (512, 512),在 OHE 情况下为 (512, 512, 8),而 (512, 512, 3)在堆叠的情况下。
有些蒙版可能包含所有类别标签,有些可能只有几个或仅是背景标签。我尝试过使用这些单通道掩码,我还将它们转换为 3 通道掩码,第一个通道是给定图像的所有标签,并且我还尝试了对它们进行热编码,以便每个掩码都是 0- n 维度,每个通道都有不同的标签,带有二进制 0-1 表示背景/目标。
编辑 更改 softmax
dim=2
后,输出开始看起来好一些。然而,在最初的几个预热阶段之后,模型似乎根本没有学习,因为训练损失最初减少,但随后立即趋于稳定或增加,并且预测掩模不再有意义(全黑或随机斑点)。我怀疑我的训练管道(如下)存在问题,或者可能是由于 0 类(背景)的类不平衡造成的。
import os
import torch
import numpy as np
from skimage import io
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp
image_dir = r'test_segmentation\images'
mask_dir = r'test_segmentation\masks'
data_dir=r'unet_training'
os.makedirs(data_dir, exist_ok=True)
model_dir = os.path.join(data_dir, 'models')
os.makedirs(model_dir, exist_ok=True)
pred_dir = os.path.join(data_dir, 'predictions')
os.makedirs(pred_dir, exist_ok=True)
num_bands = 8
num_classes = 9
epochs = 10
learning_rate = 0.001
weight_decay = 0
encoder = 'resnet50'
encoder_weights = 'imagenet'
model = smp.Unet(in_channels=num_bands, encoder_name=encoder, encoder_weights=encoder_weights, classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_function = nn.CrossEntropyLoss() if num_classes > 1 else nn.BCEWithLogitsLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for epoch in range(1, epochs + 1):
train_loss = 0
val_loss = 0
train_loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch} Training")
model.train()
for batch_idx, (data, targets) in train_loop:
optimizer.zero_grad()
data = data.float().to(device)
targets = targets.long().to(device)
predictions = model(data)
loss = loss_function(predictions, targets)
train_loss += loss.item()
loss.backward()
optimizer.step()
train_loop.set_postfix(loss=train_loss)
val_loop = tqdm(enumerate(val_loader), total=len(val_loader), desc=f"Epoch {epoch} Validation")
model.eval()
for batch_idx, (data, targets) in val_loop:
data, targets = data.to(device).float(), targets.to(device).long()
preds = model(data)
val_loss = loss_function(preds, targets).item()
softmax = torch.nn.Softmax(dim=2)
preds = torch.argmax(softmax(preds), dim=1).cpu().numpy()
preds = np.array(preds[0, :, :], dtype=np.uint8)
labels = np.array(targets.cpu().numpy()[0, :, :], dtype=np.uint8)
#save prediction and label mask
pred_path = os.path.join(pred_dir, f"{epoch}_{batch_idx}_pred.png")
label_path = os.path.join(pred_dir, f"{epoch}_{batch_idx}_label.png")
io.imsave(pred_path, preds)
io.imsave(label_path, labels)
val_loop.set_postfix(loss=val_loss)
avg_train_loss = train_loss / (batch_idx + 1)
avg_val_loss = val_loss/ (batch_idx + 1)
print(f"\nEpoch {epoch} Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}")
checkpoint_name = os.path.join(model_dir, f"{modeltype}_bands{num_bands}_classes{num_classes}_{encoder}_{learning_rate}_{epoch}.pt")
if epoch == 1:
torch.save(model.state_dict(), checkpoint_name)
elif epoch % 10 == 0:
torch.save(model.state_dict(), checkpoint_name)
elif epoch == epochs:
torch.save(model.state_dict(), checkpoint_name)
else:
pass
在开始的训练循环中
for batch_idx, (data, targets) in train_loop:
,检查以下内容:
softmax = torch.nn.Softmax(dim=1)
以便在通道维度上进行 softmax dim=1
targets.shape
应该是(batch, 512, 512)
targets
应该是一个范围为[0, n_classes - 1]
的整数,表示类别data.shape
应该是(batch, channels, 512, 512)
preds.shape
应该是(batch, 512, 512)
原回复:
火车损失有下降吗?在每个时期打印出来很有用。
尝试仅拟合一张图像或一小批图像 - 继续运行它,直到损失进一步下降。合理的输出应该开始出现,就像大致在正确位置的斑点一样。如果没有,则表明管道在某个地方损坏了,因为它根本无法学习
最初使用下采样图像并将掩模限制为单个通道也可能是值得的。它们将帮助网络更快地收敛并突出收敛问题。