我尝试用 pytorch 从头开始制作我的 UNET。我的模型输出除了黑色面具之外什么也没有。我需要对汽车的损坏进行分段,因此我实现了彩色图。我确信我的数据集有 70% 的问题与此颜色图完全一致。任务是多类预测,所以我使用交叉熵损失函数。我将提供我的数据集和训练文件的代码。
# dataset.py
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch
class Segm_Dataset(Dataset):
def __init__(self, image_dir, mask_dir, color_map):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.image_files = os.listdir(self.image_dir)
self.mask_files = os.listdir(self.mask_dir)
self.color_map = color_map
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_path = os.path.join(self.image_dir, self.image_files[idx])
mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
image = np.array(Image.open(image_path).convert('RGB'))
mask = np.array(Image.open(mask_path).convert('RGB'), dtype=np.float32)
label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)
for color, label in self.color_map.items():
color_array = np.array(color, dtype=np.float32)
mask_area = np.all(mask == color_array, axis=-1)
label_mask[mask_area] = label
image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
label_mask = torch.tensor(label_mask, dtype=torch.long)
return image, label_mask
# train.py
from model import UNET
from tqdm import tqdm
from dataset import Segm_Dataset
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import os
LEARNING_RATE = 1e-4
BATCH_SIZE = 5
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 180
IMAGE_WIDTH = 180
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = r'data\train\images'
TRAIN_MASK_DIR = r'data\train\masks'
VAL_IMG_DIR = r'data\val\images'
VAL_MASK_DIR = r'data\val\masks'
SAVED_MODELS_PATH = r'saved_models'
color_map = {
(19, 164, 201): 0, # Missing part: #13A4C9
(166, 255, 71): 1, # Broken part: #A6FF47
(180, 45, 56): 2, # Scratch: #B42D38
(225, 150, 96): 3, # Cracked: #E19660
(144, 60, 89): 4, # Dent: #903C59
(167, 116, 27): 5, # Flaking: #A7741B
(180, 14, 19): 6, # Paint chip: #B40E13
(115, 194, 206): 7, # Corrosion: #73C2CE
}
train_dataset = Segm_Dataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, color_map)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = Segm_Dataset(VAL_IMG_DIR, VAL_MASK_DIR, color_map)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)
model = UNET(in_channels=3, out_channels=len(color_map))
model = model.cuda() if torch.cuda.is_available() else model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
for epoch in range(NUM_EPOCHS):
train_loop = tqdm(enumerate(train_loader), total=len(train_loader))
for batch_index, (data, targets) in train_loop:
#Forward pass
scores = model(data)
train_loss = criterion(scores, targets)
#Backward pass
optimizer.zero_grad()
train_loss.backward()
#Gradient descent or optimizer step
optimizer.step()
if batch_index % 10 == 0:
current_batch = batch_index
val_loss = 0
with torch.no_grad():
for val_data, val_targets in val_loader:
val_scores = model(val_data)
val_loss = criterion(val_scores, val_targets)
#Update progress bar
train_loop.set_description(f'Epoch: [{epoch+1}/{NUM_EPOCHS}]')
train_loop.set_postfix(train_loss=train_loss.item(), val_loss=val_loss.item(), val_batch=current_batch)
else:
train_loop.set_description(f'Epoch: [{epoch+1}/{NUM_EPOCHS}]')
train_loop.set_postfix(train_loss=train_loss.item(), val_loss=val_loss.item(), val_batch=current_batch)
checkpoint = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss.item(),
'val_loss': val_loss.item()
}
torch.save(checkpoint, os.path.join(SAVED_MODELS_PATH, f'unet_epoch_{epoch}.pth'))
一些训练时期:
Epoch: [9/10]: 100%|██████████████████| 888/888 [34:24<00:00, 2.32s/it, train_loss=0.000271, val_batch=880, val_loss=0.000278]
Epoch: [10/10]: 100%|█████████████████| 888/888 [34:29<00:00, 2.33s/it, train_loss=0.000163, val_batch=880, val_loss=0.000167]
你尝试过不同的学习率吗? 我也遇到了类似的问题,问题是学习率太小了。
还有一个帖子讨论了类似的问题: 训练uNet模型预测只有黑色
您的图像最终采用什么范围的值?它看起来像是 0 到 256 之间的值。最佳实践是将它们标准化在 (0,1) 或 (-1,1) 之间,因此只需除以 256 即可。
也可以按照建议尝试更大的批量大小/更低的 LR。