I'm building my own dataset:
from glob import glob

import matplotlib.image as mpimg
import numpy as np
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, folders):
        self.folders = folders

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, item):
        pos_file_list = glob(self.folders[item] + "/*")
        positive_img = pos_file_list[1]
        positive_img = mpimg.imread(positive_img)
        positive_img = np.transpose(positive_img, (2, 0, 1))
        # positive_img has type <class 'numpy.ndarray'>, shape (3, 128, 128)
        return positive_img
I'm using it like this:
batch_size = 128
train_ds = MyDataset(train_folder_list)
oTrainDL = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
for i, imgs in enumerate(oTrainDL):
    break
I get the following warning:
UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:189.)
return default_collate([torch.as_tensor(b) for b in batch])
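The warning concerns NumPy's writeable flag. Below is a minimal NumPy-only sketch of how a read-only array survives the steps in __getitem__. It assumes mpimg.imread returned read-only memory; the dummy image has its flag cleared by hand to simulate that. np.transpose returns a view rather than a copy, and a view of a read-only array is itself read-only, whereas .copy() produces a fresh, writeable array.

```python
import numpy as np

# Dummy 128x128 RGB image standing in for the output of mpimg.imread.
# Assumption: the decoder handed back read-only memory; simulated here.
img = np.zeros((128, 128, 3), dtype=np.float32)
img.flags.writeable = False

# np.transpose returns a view, so it inherits the read-only flag.
chw = np.transpose(img, (2, 0, 1))
print(chw.shape, chw.flags.writeable)  # (3, 128, 128) False

# An explicit copy owns its own memory and is writeable again.
chw_copy = chw.copy()
print(chw_copy.flags.writeable)  # True
```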
Why do I get this warning, and how can I fix it?
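Change
return positive_img
to:
return torch.tensor(positive_img)
(this needs import torch at the top of the module). The NumPy array coming out of mpimg.imread is not writeable, as the warning says, and np.transpose returns a view that keeps that flag. default_collate then wraps each such array with torch.as_tensor(), which shares memory with it and therefore warns. torch.tensor() always copies its input, so each sample becomes a tensor that owns writeable memory, and default_collate simply stacks those tensors.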
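Here is a self-contained sketch of the fix. It uses a hypothetical ReadOnlyImageDataset that generates dummy read-only images instead of reading files, so no folders are needed, and num_workers=0 keeps it simple. The dataset returns torch.tensor(chw) from __getitem__, and the script checks that collating a batch raises no "not writeable" warning.

```python
import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ReadOnlyImageDataset(Dataset):
    """Hypothetical stand-in for MyDataset: yields read-only HWC images
    instead of reading files, so the example runs without any data."""

    def __init__(self, n_items):
        self.n_items = n_items

    def __len__(self):
        return self.n_items

    def __getitem__(self, item):
        img = np.full((128, 128, 3), item, dtype=np.float32)
        img.flags.writeable = False         # simulate read-only decoder output
        chw = np.transpose(img, (2, 0, 1))  # read-only view, shape (3, 128, 128)
        # torch.tensor() always copies, so the tensor owns writeable memory
        return torch.tensor(chw)


loader = DataLoader(ReadOnlyImageDataset(8), batch_size=4, shuffle=True, num_workers=0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    batch = next(iter(loader))

print(batch.shape)  # torch.Size([4, 3, 128, 128])

# Collect any "not writeable" warnings raised while collating the batch
not_writeable = [w for w in caught if "not writeable" in str(w.message)]
print(len(not_writeable))  # 0
```

Returning torch.from_numpy(chw.copy()) should work equally well. The copy is writeable, so from_numpy has nothing to warn about. Either way, each image is copied once.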