具有 n 维卫星图像的多类 UNet

问题描述 投票:0回答:1

我正在尝试使用 Pytorch 中的 UNet 从多维(8 波段)卫星图像中提取预测掩模。我无法让预测蒙版看起来有些预期/连贯。我不确定问题是否在于我的训练数据的格式化方式、我的训练代码或我用于进行预测的代码。我怀疑这是我的训练数据输入模型的方式。我有 8 个波段卫星图像和单波段掩码,其值范围为 0-n 个类别,其中 0 是背景,1-n 是目标标签,如下所示:

enter image description here

在单通道示例的情况下,图像形状为 (8, 512, 512),掩模形状为 (512, 512),在 OHE 情况下为 (512, 512, 8),而 (512, 512, 3)在堆叠的情况下。

有些蒙版可能包含所有类别标签,有些可能只有几个或仅是背景标签。我尝试过使用这些单通道掩码,我还将它们转换为 3 通道掩码,第一个通道是给定图像的所有标签,并且我还尝试了对它们进行热编码,以便每个掩码都是 0- n 维度,每个通道都有不同的标签,带有二进制 0-1 表示背景/目标。

编辑 更改 softmax

dim=2
后,输出开始看起来好一些。然而,在最初的几个预热阶段之后,模型似乎根本没有学习,因为训练损失最初减少,但随后立即趋于稳定或增加,并且预测掩模不再有意义(全黑或随机斑点)。我怀疑我的训练管道(如下)存在问题,或者可能是由于 0 类(背景)的类不平衡造成的。

import os
import torch
import numpy as np
from skimage import io
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp

image_dir = r'test_segmentation\images'
mask_dir = r'test_segmentation\masks'

data_dir=r'unet_training'
os.makedirs(data_dir, exist_ok=True)

model_dir = os.path.join(data_dir, 'models')
os.makedirs(model_dir, exist_ok=True)

pred_dir = os.path.join(data_dir, 'predictions')
os.makedirs(pred_dir, exist_ok=True)

num_bands = 8
num_classes = 9
epochs = 10
learning_rate = 0.001
weight_decay = 0
encoder = 'resnet50'
encoder_weights = 'imagenet'

model = smp.Unet(in_channels=num_bands, encoder_name=encoder, encoder_weights=encoder_weights, classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_function = nn.CrossEntropyLoss() if num_classes > 1 else nn.BCEWithLogitsLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for epoch in range(1, epochs + 1):
    train_loss = 0
    val_loss = 0 

    train_loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch} Training")
    
    model.train()
    
    for batch_idx, (data, targets) in train_loop:
        optimizer.zero_grad()
        
        data = data.float().to(device)
        targets = targets.long().to(device)
        predictions = model(data)
        loss = loss_function(predictions, targets)
        
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
    
        train_loop.set_postfix(loss=train_loss)

    val_loop = tqdm(enumerate(val_loader), total=len(val_loader), desc=f"Epoch {epoch} Validation")

    model.eval()
    
    for batch_idx, (data, targets) in val_loop:
        data, targets = data.to(device).float(), targets.to(device).long()
        preds = model(data)

        val_loss = loss_function(preds, targets).item()

        softmax = torch.nn.Softmax(dim=2)
        preds = torch.argmax(softmax(preds), dim=1).cpu().numpy()
        preds = np.array(preds[0, :, :], dtype=np.uint8)
        labels = np.array(targets.cpu().numpy()[0, :, :], dtype=np.uint8)

        #save prediction and label mask
        pred_path = os.path.join(pred_dir, f"{epoch}_{batch_idx}_pred.png")
        label_path = os.path.join(pred_dir, f"{epoch}_{batch_idx}_label.png")
        io.imsave(pred_path, preds)
        io.imsave(label_path, labels)

        val_loop.set_postfix(loss=val_loss)
    
    avg_train_loss = train_loss / (batch_idx + 1)
    avg_val_loss = val_loss/ (batch_idx + 1)

    print(f"\nEpoch {epoch} Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}")

    checkpoint_name = os.path.join(model_dir, f"{modeltype}_bands{num_bands}_classes{num_classes}_{encoder}_{learning_rate}_{epoch}.pt")
    
    if epoch == 1:
        torch.save(model.state_dict(), checkpoint_name)
    elif epoch % 10 == 0:
        torch.save(model.state_dict(), checkpoint_name)
    elif epoch == epochs:
        torch.save(model.state_dict(), checkpoint_name)
    else:
        pass
python machine-learning deep-learning pytorch semantic-segmentation
1个回答
0
投票

在开始的训练循环中

for batch_idx, (data, targets) in train_loop:
,检查以下内容:

  • 更改为
    softmax = torch.nn.Softmax(dim=1)
    以便在通道维度上进行 softmax
    dim=1
  • targets.shape
    应该是
    (batch, 512, 512)
    • 在每个像素处,
      targets
      应该是一个范围为
      [0, n_classes - 1]
      的整数,表示类别
  • data.shape
    应该是
    (batch, channels, 512, 512)
  • preds.shape
    应该是
    (batch, 512, 512)

原回复:

火车损失有下降吗?在每个时期打印出来很有用。

尝试仅拟合一张图像或一小批图像 - 继续运行它,直到损失进一步下降。合理的输出应该开始出现,就像大致在正确位置的斑点一样。如果没有,则表明管道在某个地方损坏了,因为它根本无法学习

最初使用下采样图像并将掩模限制为单个通道也可能是值得的。它们将帮助网络更快地收敛并突出收敛问题。

© www.soinside.com 2019 - 2024. All rights reserved.