我正在研究分割模型并使用 ImageDataGenerator() 加载和增强图像及其蒙版。我将它们作为压缩对象返回
def create_segmentation_generator_train(img_path, msk_path, BATCH_SIZE):
data_gen_args = #augmentation code
datagen = ImageDataGenerator(**data_gen_args)
img_generator = datagen.flow_from_directory(img_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
msk_generator = datagen.flow_from_directory(msk_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
return zip(img_generator, msk_generator)
# Remember not to perform any image augmentation in the test generator!
def create_segmentation_generator_test(img_path, msk_path, BATCH_SIZE):
data_gen_args = dict(rescale=1./255)
datagen = ImageDataGenerator(**data_gen_args)
img_generator = datagen.flow_from_directory(img_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
msk_generator = datagen.flow_from_directory(msk_path, target_size=IMG_SIZE, class_mode=None, color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
return zip(img_generator, msk_generator)
train_generator = create_segmentation_generator_train(data_dir_train_image, data_dir_train_mask, BATCH_SIZE_TRAIN)
test_generator = create_segmentation_generator_test(data_dir_test_image, data_dir_test_mask, BATCH_SIZE_TEST)
但是,当我将生成器传递给 model.fit() 时,它会抛出一个错误,指出我正在传递无法识别的数据类型
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[12], line 1
----> 1 model.fit(x=train_generator,
2 steps_per_epoch=EPOCH_STEP_TRAIN,
3 validation_data=test_generator,
4 validation_steps=EPOCH_STEP_TEST,
5 epochs=NUM_OF_EPOCHS)
File c:\Users\biome\anaconda3\envs\TCIA\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File c:\Users\biome\anaconda3\envs\TCIA\Lib\site-packages\keras\src\trainers\data_adapters\__init__.py:120, in get_data_adapter(x, y, sample_weight, batch_size, steps_per_epoch, shuffle, class_weight)
112 return GeneratorDataAdapter(x)
113 # TODO: should we warn or not?
114 # warnings.warn(
115 # "`shuffle=True` was passed, but will be ignored since the "
(...)
118 # )
119 else:
--> 120 raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})")
ValueError: Unrecognized data type: x=<zip object at 0x00000172364F6040> (of type <class 'zip'>)
我该如何修复这个错误?
要压缩
tf.data.Dataset
对象,您应该使用 tf.data.Dataset.zip,而不是本机 python 函数。
return tf.data.Dataset.zip(img_generator, msk_generator)