model.fit() throws an error when given a zipped dataset


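I am working on a segmentation model and use ImageDataGenerator() to load and augment the images and their masks. I return the two generators as a zipped object: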

def create_segmentation_generator_train(img_path, msk_path, BATCH_SIZE):
    data_gen_args = dict()  # augmentation arguments go here (elided in the original post)
    datagen = ImageDataGenerator(**data_gen_args)
    
    img_generator = datagen.flow_from_directory(img_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
    msk_generator = datagen.flow_from_directory(msk_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
    return zip(img_generator, msk_generator)

# Remember not to perform any image augmentation in the test generator!
def create_segmentation_generator_test(img_path, msk_path, BATCH_SIZE):
    data_gen_args = dict(rescale=1./255)
    datagen = ImageDataGenerator(**data_gen_args)
    
    img_generator = datagen.flow_from_directory(img_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
    msk_generator = datagen.flow_from_directory(msk_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
    return zip(img_generator, msk_generator)

train_generator = create_segmentation_generator_train(data_dir_train_image, data_dir_train_mask, BATCH_SIZE_TRAIN)
test_generator = create_segmentation_generator_test(data_dir_test_image, data_dir_test_mask, BATCH_SIZE_TEST)

However, when I pass the generator to model.fit(), it throws an error saying that I am passing an unrecognized data type:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[12], line 1
----> 1 model.fit(x=train_generator, 
      2                     steps_per_epoch=EPOCH_STEP_TRAIN, 
      3                     validation_data=test_generator, 
      4                     validation_steps=EPOCH_STEP_TEST,
      5                    epochs=NUM_OF_EPOCHS)

File c:\Users\biome\anaconda3\envs\TCIA\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File c:\Users\biome\anaconda3\envs\TCIA\Lib\site-packages\keras\src\trainers\data_adapters\__init__.py:120, in get_data_adapter(x, y, sample_weight, batch_size, steps_per_epoch, shuffle, class_weight)
    112     return GeneratorDataAdapter(x)
    113     # TODO: should we warn or not?
    114     # warnings.warn(
    115     #     "`shuffle=True` was passed, but will be ignored since the "
   (...)
    118     # )
    119 else:
--> 120     raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})")

ValueError: Unrecognized data type: x=<zip object at 0x00000172364F6040> (of type <class 'zip'>)

How can I fix this error?

python tensorflow keras error-handling
1 Answer

To zip tf.data.Dataset objects, use tf.data.Dataset.zip rather than Python's built-in zip function.

return tf.data.Dataset.zip(img_generator, msk_generator)
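
Note that flow_from_directory returns Keras iterators rather than tf.data.Dataset objects, so the call above would probably still need the two iterators converted to datasets first. An alternative that keeps the generators unchanged is to wrap the zip in a generator function. Judging from the traceback, Keras appears to accept true Python generators (the GeneratorDataAdapter branch) but not zip objects; that reading is an assumption, not something the post confirms. The sketch below uses plain lists as stand-ins for the two iterators, and the name paired_batches is hypothetical:

```python
import inspect


def paired_batches(img_batches, msk_batches):
    """Yield (image_batch, mask_batch) tuples from two parallel iterables."""
    for img_batch, msk_batch in zip(img_batches, msk_batches):
        yield img_batch, msk_batch


# Stand-ins for the two flow_from_directory iterators (hypothetical data).
img_batches = ["img_batch_0", "img_batch_1"]
msk_batches = ["msk_batch_0", "msk_batch_1"]

zipped = zip(img_batches, msk_batches)
wrapped = paired_batches(img_batches, msk_batches)

print(inspect.isgenerator(zipped))   # False: a zip object is not a generator
print(inspect.isgenerator(wrapped))  # True: calling a generator function returns a generator
```

In the question's functions, this would mean ending with return paired_batches(img_generator, msk_generator) instead of return zip(img_generator, msk_generator). Because the directory iterators loop indefinitely, steps_per_epoch and validation_steps still need to be passed to model.fit(), as the question already does.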