Unable to augment images with ImageDataGenerator (from tensorflow.keras.preprocessing.image import ImageDataGenerator)


I want to create an augmented version of my dataset. The dataset is manually labeled and was downloaded from this RIWA dataset.

I build the classes with this code:

import os
import pandas as pd

source_dir = r'./river-water-segmentation-dataset/riwa_v2'
subdir = os.listdir(source_dir)  # each sub-folder is treated as one class

filepaths = []
labels = []

for i in subdir:
    classpath = os.path.join(source_dir, i)

    if os.path.isdir(classpath):
        file_list = os.listdir(classpath)
        for f in file_list:
            file_path = os.path.join(classpath, f)
            filepaths.append(file_path)
            labels.append(i)
paths = pd.Series(filepaths, name='paths')
labels = pd.Series(labels, name='labels')

df = pd.concat([paths, labels], axis=1)

print(df.head())
print("========================")
print(df['labels'].value_counts())
print("=========================")
print('Total data: ', len(df))
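The loop above stores every entry of each class folder as an image path, including entries that are themselves folders (for example images/ or masks/) or non-image files. A minimal sketch of a stricter version, assuming the same "one top-level sub-folder per label" layout, keeps only files with an image extension. build_image_dataframe and IMAGE_EXTS are hypothetical names, and the extension list is my assumption about Keras' defaults.

```python
import os
import tempfile

import pandas as pd

# Image extensions accepted by default by Keras' directory/dataframe
# iterators, as far as I know (assumption; adjust if your version differs).
IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')


def build_image_dataframe(source_dir):
    """Return a DataFrame with one (paths, labels) row per image file.

    Each immediate sub-folder of source_dir is a label, as in the loop above,
    but entries that are folders or non-image files are skipped instead of
    being stored as "image" paths.
    """
    rows = []
    for label in sorted(os.listdir(source_dir)):
        class_dir = os.path.join(source_dir, label)
        if not os.path.isdir(class_dir):
            continue
        for name in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, name)
            if os.path.isfile(path) and name.lower().endswith(IMAGE_EXTS):
                rows.append((path, label))
    return pd.DataFrame(rows, columns=['paths', 'labels'])


# Usage on a throw-away folder tree (hypothetical file names).
with tempfile.TemporaryDirectory() as tmp:
    for rel in ['images/a.jpg', 'images/b.PNG', 'masks/c.png', 'masks/notes.txt']:
        path = os.path.join(tmp, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, 'w').close()
    os.makedirs(os.path.join(tmp, 'images', 'nested'))  # a sub-folder, not an image
    demo_df = build_image_dataframe(tmp)

print(demo_df['labels'].value_counts())
```

If this returns zero rows for riwa_v2, the images are nested one level deeper than the labels, which is worth knowing before any generator is built.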

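Then I cap the number of samples per class. I planned on 700 per class as a starting value (the code below currently uses max_size = 1500), and I may increase it later to get a larger dataset: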

sample_list = []
max_size = 1500  # TODO: change this value (maximum number of samples kept per class)

grouping = df.groupby('labels')

for label in df['labels'].unique():
    group = grouping.get_group(label)
    group_size = len(group)

    if group_size > max_size:
        samples = group.sample(max_size, replace=False, weights=None, axis=0).reset_index(drop=True)
    else:
        samples = group.sample(frac=1.0, replace=False, axis=0).reset_index(drop=True)
    sample_list.append(samples)

df = pd.concat(sample_list, axis=0).reset_index(drop=True)
print(df['labels'].value_counts())
print('Total data: ', len(df))
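The capping loop can also be written more compactly by iterating over the groups directly. This is a sketch of an equivalent, not the original code; cap_per_class is a hypothetical helper, and the optional seed is an addition that makes the sampling reproducible.

```python
import pandas as pd


def cap_per_class(df, max_size, label_col='labels', seed=None):
    """Shuffle every class and keep at most max_size rows of it.

    Classes larger than max_size are down-sampled without replacement;
    smaller classes are only shuffled (same effect as the loop above).
    """
    parts = [
        group.sample(n=min(len(group), max_size), replace=False, random_state=seed)
        for _, group in df.groupby(label_col, sort=False)
    ]
    return pd.concat(parts, axis=0).reset_index(drop=True)


# Usage with made-up data: 5 rows of class 'a', 2 rows of class 'b'.
toy = pd.DataFrame({'paths': [f'img{i}.jpg' for i in range(7)],
                    'labels': ['a'] * 5 + ['b'] * 2})
capped = cap_per_class(toy, max_size=3, seed=0)
print(capped['labels'].value_counts())
```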

From there I create the augmented dataset with this:

import os
import shutil
from tensorflow.keras.preprocessing.image import ImageDataGenerator

working_dir = r'./river-water-segmentation-dataset/riwa_v2/cropped'

# Recreate an empty aug/ folder with one sub-folder per class
aug_dir = os.path.join(working_dir, 'aug')
if os.path.isdir(aug_dir):
    shutil.rmtree(aug_dir)
os.makedirs(aug_dir)
for label in df['labels'].unique():
    dir_path = os.path.join(aug_dir, label)
    os.mkdir(dir_path)
print(os.listdir(aug_dir))

target = 700  # set the target count for each class in df
gen = ImageDataGenerator(
    rotation_range=90,
    horizontal_flip=True,
    vertical_flip=True,
)

grouping = df.groupby('labels')  # group by class

for label in df['labels'].unique():  # for every class
    group = grouping.get_group(label)  # a dataframe holding only rows with the specified label
    sample_count = len(group)  # determine how many samples there are in this class
    if sample_count < target:  # if the class has fewer than target images
        aug_img_count = 0
        delta = target - sample_count  # number of augmented images to create
        target_dir = os.path.join(aug_dir, label)  # define where to write the images

        aug_gen = gen.flow_from_dataframe(
            group,
            x_col='paths',
            y_col=None,
            target_size=(1420, 1080),  # (height, width); change based on the transfer learning model
            class_mode=None,
            batch_size=1,
            shuffle=False,
            save_to_dir=target_dir,
            save_prefix='aug-',
            save_format='jpg',
        )
        # "Found 0 validated image filenames" means every row was rejected.
        # next() then yields empty batches, so aug_img_count would never grow
        # and the while loop below would never finish.
        if aug_gen.n == 0:
            print(f"No valid images for label '{label}'. Skipping augmentation.")
            continue

        while aug_img_count < delta:
            images = next(aug_gen)  # each call also saves the batch to target_dir
            aug_img_count += len(images)
        print(f"Generated {aug_img_count} images for '{label}'.")
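Before running the generator it can help to know how many augmented images each class actually needs. A small sketch that mirrors the `delta = target - sample_count` logic of the loop above; augmentation_plan is a hypothetical helper and the class sizes in the usage line are invented, not the real RIWA counts.

```python
import pandas as pd


def augmentation_plan(labels, target):
    """Map each class below `target` to the number of images still needed.

    Mirrors `delta = target - sample_count` in the loop above; classes that
    already have `target` or more samples are left out.
    """
    counts = pd.Series(labels).value_counts()
    return {label: int(target - n) for label, n in counts.items() if n < target}


# Hypothetical class sizes, not the real RIWA counts.
plan = augmentation_plan(['images'] * 650 + ['masks'] * 700, target=700)
print(plan)
```

With batch_size=1 each next(aug_gen) call writes one file, so 'images' here needs 50 calls. Any next(aug_gen) call made outside the counting loop (for example a test fetch) also writes a file into save_to_dir, so it should be counted as well.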
            

First, the import

from tensorflow.keras.preprocessing.image import ImageDataGenerator

is my adaptation of

from tensorflow.preprocessing.image import ImageDataGenerator

from another answer (which I can no longer find): in this version of Keras the class has moved to the tensorflow.keras.preprocessing.image import path.

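I have let the code run for more than 10 minutes, and the only output is still

Found 0 validated image filenames

Am I doing something wrong? Is it because I installed the CPU version of TensorFlow?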
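flow_from_dataframe prints "Found 0 validated image filenames" when none of the rows in x_col points to an existing file with a supported image extension. A quick way to see which rows are rejected and why is to check each path in df before creating the generator. This is a minimal diagnostic sketch: explain_invalid_paths is a hypothetical helper, and the extension list is an assumption about Keras' defaults.

```python
import os
import tempfile

import pandas as pd

# Assumed to match the default extensions flow_from_dataframe accepts.
IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')


def explain_invalid_paths(df, x_col='paths'):
    """Return the rows whose path is not an existing image file, with a reason."""
    reasons = []
    for path in df[x_col]:
        if os.path.isdir(path):
            reasons.append('is a directory')
        elif not os.path.isfile(path):
            reasons.append('does not exist')
        elif not path.lower().endswith(IMAGE_EXTS):
            reasons.append('unsupported extension')
        else:
            reasons.append(None)  # would be accepted
    checked = df.assign(reason=reasons)
    return checked[checked['reason'].notna()]


# Usage on a throw-away folder with made-up paths.
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, 'images'))
    for name in ['images/a.jpg', 'notes.txt']:
        open(os.path.join(tmp, name), 'w').close()
    paths = [os.path.join(tmp, p) for p in ['images', 'images/a.jpg', 'missing.png', 'notes.txt']]
    bad = explain_invalid_paths(pd.DataFrame({'paths': paths}))
    print(bad)
```

Running explain_invalid_paths(df) on the question's DataFrame would show, for instance, whether every stored path is a folder such as images/ or masks/ rather than an image file.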

Edit 1: I suspected an image size problem, so I cropped all images in the dataset to the same size, but the code still does not work.

python tensorflow keras
1 Answer

Votes: 0

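The problem is with the RIWA dataset, which keeps the images and the masks in separate folders. I moved both folders (images and masks) under the train folder, so that the train folder now contains images and masks. After making these adjustments and switching to the

ImageDataGenerator.flow_from_directory

method, the code works. Please refer to this gist.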
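With flow_from_directory, each sub-folder of the given directory is treated as one class, so after the move the generator sees two classes, images and masks, under train/. The stdlib sketch below previews what it should find. preview_directory_classes is a hypothetical helper; the recursive search inside each class folder and the extension list reflect my understanding of Keras' defaults and are assumptions, not Keras code.

```python
import os
import tempfile

IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')


def preview_directory_classes(directory):
    """Return {class_name: image_count} for each immediate sub-folder.

    Assumption: like flow_from_directory's default, every sub-folder is a
    class and image files are searched for recursively inside it.
    """
    classes = {}
    for name in sorted(os.listdir(directory)):
        class_dir = os.path.join(directory, name)
        if not os.path.isdir(class_dir):
            continue
        count = 0
        for _root, _dirs, files in os.walk(class_dir):
            count += sum(f.lower().endswith(IMAGE_EXTS) for f in files)
        classes[name] = count
    return classes


# Usage: a made-up train/ folder holding images/ and masks/ side by side.
with tempfile.TemporaryDirectory() as tmp:
    train_dir = os.path.join(tmp, 'train')
    for rel in ['images/a.jpg', 'images/b.jpg', 'masks/a.png']:
        path = os.path.join(train_dir, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, 'w').close()
    found = preview_directory_classes(train_dir)
print(found)
```

A class with a count of 0 here would produce the same "Found 0" symptom, so this is a cheap check before starting a long augmentation run.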

© www.soinside.com 2019 - 2024. All rights reserved.