我正在尝试使用 pytorch 训练 U-Net 模型构建。对于这种情况,我构建了数据集并在图像和掩模中应用了数据增强转换。情况是我想对两者应用相同的变换,这意味着,如果我将图像旋转一定的度数,我希望蒙版旋转相同的度数,这就是我的问题。图像和蒙版的旋转量不同。
我留下以下代码:
数据集
import torch
from torch.utils.data import Dataset
import os
class INBreastDataset2012(Dataset):
def __init__(self, dict_dir, transform=None):
self.dict_dir = dict_dir
self.data = os.listdir(self.dict_dir)
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, index):
dict_path = os.path.join(self.dict_dir, self.data[index])
patient_dict = torch.load(dict_path)
image = patient_dict['image'].unsqueeze(0)
mass_mask = patient_dict['mass_mask'].unsqueeze(0)
mass_mask[mass_mask > 1.0] = 1.0
if self.transform is not None:
image = self.transform(image)
mass_mask = self.transform(mass_mask)
return image, mass_mask
“训练”(此时并不是真正的训练,只是数据加载器带来的信息的可视化)
from dataset import INBreastDataset2012
from torchvision.transforms import v2 as T
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
train_dir = r'directory\of\training images and masks'
test_dir = r'directory\of\testing images and masks'
train_transform = T.Compose(
[
T.RandomRotation(degrees=35, expand=True, fill=255.0),
T.RandomHorizontalFlip(p=0.5),
T.RandomVerticalFlip(p=0.5),
]
)
train_data = INBreastDataset2012(train_dir,transform=train_transform)
test_data = INBreastDataset2012(test_dir)
train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)
plt.figure(figsize=(12,12))
for i, (imagen,mascara) in enumerate(train_dataloader):
ax = plt.subplot(2,4,i+1)
ax.title.set_text(f'imagen {i+1}')
plt.imshow(imagen.squeeze(), cmap='gray')
ax = plt.subplot(2,4,i+3)
ax.title.set_text(f'mascara de imagen {i+1}')
plt.imshow(mascara.squeeze(), cmap='gray')
if i == 1:
break
我还要补充一点,我已经尝试过使用 albumentations 和 torchvision.transforms v1。在 pytorch 和 youtube 视频的示例中,他们似乎做了和我一样的事情。
有人可以帮助我看看我做错了什么,或者有一个解决方案来确保转换相同,我将不胜感激。
如果需要任何额外信息,请询问。这是我的第一篇文章,所以我可能错过了一些东西。 先谢谢你了
您可以尝试沿通道维度连接图像和掩模,运行变换,然后将结果拆分回两个张量。
...
if self.transform is not None:
#Concatenate along channel dimension
image_and_mask = torch.cat([image, mask], dim=1)
#Transform together
transformed = self.transform(image_and_mask)
#Slice the tensors out
image = transformed[:, :image.shape[1], ...]
mass_mask = transformed[:, image.shape[1]:, ...]
...