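I am training a U-Net-based segmentation network and using Keras's ImageDataGenerator to augment my grayscale images on the fly. Everything works as expected unless I include brightness_range in the arguments. When I do, my (512, 512, 1) images seem to turn into (512, 512) images, which breaks everything. How can I fix this?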
Here is my augmentation code:
from keras.preprocessing.image import ImageDataGenerator

data_gen_args = dict(
    rotation_range=15,
    shear_range=45,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=[0.5, 1.5],
    # horizontal_flip=True,
    # vertical_flip=True,
    brightness_range=[0.5, 1.5],  # adding this line triggers the shape error
    fill_mode='nearest'
)

image_datagen_train = ImageDataGenerator(**data_gen_args)

# train_ct, train_mask, seed, BS and mask_datagen_train are defined elsewhere (not shown)
train_image_generator = image_datagen_train.flow_from_directory(
    train_ct,
    target_size=(512, 512),
    color_mode="grayscale",
    classes=None,
    class_mode=None,
    seed=seed,
    batch_size=BS)

train_mask_generator = mask_datagen_train.flow_from_directory(
    train_mask,
    target_size=(512, 512),
    color_mode="grayscale",
    classes=None,
    class_mode=None,
    seed=seed,
    batch_size=BS)
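One possible workaround, sketched under assumptions rather than taken from the original post: drop brightness_range from data_gen_args and apply the brightness change yourself through ImageDataGenerator's preprocessing_function argument, which receives one rank-3 image and is expected to return an array of the same shape. The helper make_brightness_fn is hypothetical. It draws from its own NumPy Generator instead of the global np.random state, which the image and mask iterators use for their random transforms, so it should not disturb the pairing that the shared seed provides. It also assumes pixel values are still in the 0 to 255 range when it runs, which is what flow_from_directory yields when no rescale is set.

```python
import numpy as np


def make_brightness_fn(low=0.5, high=1.5, seed=None):
    """Return a function that scales an image's brightness by a random factor.

    Hypothetical helper, meant to be passed as ImageDataGenerator's
    preprocessing_function. It multiplies every pixel by one factor drawn
    uniformly from [low, high) and clips to the 8-bit range, so an input of
    shape (H, W, 1) comes back with shape (H, W, 1).
    """
    # Private generator: does not advance the global np.random state that
    # ImageDataGenerator uses for its geometric transforms.
    rng = np.random.default_rng(seed)

    def random_brightness(x):
        factor = rng.uniform(low, high)
        # Element-wise ops keep the channel axis; assumes values in [0, 255].
        return np.clip(x * factor, 0.0, 255.0).astype(x.dtype)

    return random_brightness
```

For example, with brightness_range removed from data_gen_args, the images could use `ImageDataGenerator(preprocessing_function=make_brightness_fn(0.5, 1.5, seed=seed), **data_gen_args)` while the masks keep a plain `ImageDataGenerator(**data_gen_args)`, so the masks are never brightness-shifted.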
Here is my error message:
ValueError Traceback (most recent call last)
<ipython-input-33-8d701f27a3fa> in <module>
7 verbose=1,
8 callbacks=cb_check,
----> 9 use_multiprocessing = False
10 )
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1730 use_multiprocessing=use_multiprocessing,
1731 shuffle=shuffle,
-> 1732 initial_epoch=initial_epoch)
1733
1734 @interfaces.legacy_generator_methods_support
/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
183 batch_index = 0
184 while steps_done < steps_per_epoch:
--> 185 generator_output = next(output_generator)
186
187 if not hasattr(generator_output, '__len__'):
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
740 "`use_multiprocessing=False, workers > 1`."
741 "For more information see issue #1638.")
--> 742 six.reraise(*sys.exc_info())
/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
709 try:
710 future = self.queue.get(block=True)
--> 711 inputs = future.get(timeout=30)
712 self.queue.task_done()
713 except mp.TimeoutError:
/usr/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
642 return self._value
643 else:
--> 644 raise self._value
645
646 def _set(self, i, obj):
/usr/lib/python3.6/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
117 job, i, func, args, kwds = task
118 try:
--> 119 result = (True, func(*args, **kwds))
120 except Exception as e:
121 if wrap_exception and func is not _helper_reraises_exception:
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in next_sample(uid)
648 The next value of generator `uid`.
649 """
--> 650 return six.next(_SHARED_SEQUENCES[uid])
651
652
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in __next__(self, *args, **kwargs)
102
103 def __next__(self, *args, **kwargs):
--> 104 return self.next(*args, **kwargs)
105
106 def next(self):
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in next(self)
114 # The transformation of images is not under thread lock
115 # so it can be done in parallel
--> 116 return self._get_batches_of_transformed_samples(index_array)
117
118 def _get_batches_of_transformed_samples(self, index_array):
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in _get_batches_of_transformed_samples(self, index_array)
244 x = self.image_data_generator.apply_transform(x, params)
245 x = self.image_data_generator.standardize(x)
--> 246 batch_x[i] = x
247 # optionally save augmented images to disk for debugging purposes
248 if self.save_to_dir:
ValueError: could not broadcast input array from shape (512,512) into shape (512,512,1)
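The last frame of the traceback is a plain NumPy assignment: batch_x has room for a (512, 512, 1) image in each slot, but the augmented image x arrived without its channel axis. The sketch below uses a small 4x4 size instead of 512x512 for brevity; it reproduces the same kind of ValueError and shows that restoring the trailing axis with np.expand_dims makes the assignment succeed. The helper ensure_channel_axis is hypothetical and not part of Keras.

```python
import numpy as np

H = W = 4  # small stand-in for 512 x 512
batch_x = np.zeros((2, H, W, 1), dtype=np.float32)  # what the iterator allocates
x = np.ones((H, W), dtype=np.float32)  # augmented image that lost its channel axis

try:
    batch_x[0] = x  # same assignment as `batch_x[i] = x` in the traceback
    error = None
except ValueError as exc:
    error = str(exc)
print(error)


def ensure_channel_axis(img):
    """Hypothetical helper: turn an (H, W) array back into (H, W, 1)."""
    return np.expand_dims(img, axis=-1) if img.ndim == 2 else img


batch_x[0] = ensure_channel_axis(x)  # shapes now match, so this succeeds
print(batch_x[0].shape)  # (4, 4, 1)
```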
Here is a model in which I used brightness_range in ImageDataGenerator, and it did not cause any problems. The model runs fine.
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
import os
import numpy as np
import matplotlib.pyplot as plt
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats') # directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs') # directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats') # directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs') # directory with our validation dog pictures
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val
batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
train_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our training data
validation_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our validation data
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary')
model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1)
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size)
Output:
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
WARNING:tensorflow:From <ipython-input-2-2b8537e7d5b3>:74: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.
Epoch 1/15
15/15 [==============================] - 9s 591ms/step - loss: 1.0527 - accuracy: 0.5010 - val_loss: 0.6918 - val_accuracy: 0.5089
Epoch 2/15
15/15 [==============================] - 9s 609ms/step - loss: 0.6790 - accuracy: 0.5337 - val_loss: 0.6473 - val_accuracy: 0.5647
Epoch 3/15
15/15 [==============================] - 9s 610ms/step - loss: 0.6340 - accuracy: 0.5983 - val_loss: 0.6208 - val_accuracy: 0.6172
Epoch 4/15
15/15 [==============================] - 9s 609ms/step - loss: 0.5899 - accuracy: 0.6464 - val_loss: 0.5938 - val_accuracy: 0.6585
Epoch 5/15
15/15 [==============================] - 9s 599ms/step - loss: 0.5182 - accuracy: 0.7286 - val_loss: 0.6165 - val_accuracy: 0.7042
Epoch 6/15
15/15 [==============================] - 9s 608ms/step - loss: 0.4697 - accuracy: 0.7682 - val_loss: 0.5853 - val_accuracy: 0.7109
Epoch 7/15
15/15 [==============================] - 9s 604ms/step - loss: 0.4393 - accuracy: 0.7746 - val_loss: 0.5826 - val_accuracy: 0.7132
Epoch 8/15
15/15 [==============================] - 9s 608ms/step - loss: 0.4115 - accuracy: 0.7895 - val_loss: 0.6602 - val_accuracy: 0.7042
Epoch 9/15
15/15 [==============================] - 9s 598ms/step - loss: 0.3831 - accuracy: 0.8162 - val_loss: 0.6254 - val_accuracy: 0.7076
Epoch 10/15
15/15 [==============================] - 9s 601ms/step - loss: 0.3151 - accuracy: 0.8531 - val_loss: 0.5924 - val_accuracy: 0.7098
Epoch 11/15
15/15 [==============================] - 9s 611ms/step - loss: 0.2904 - accuracy: 0.8632 - val_loss: 0.6664 - val_accuracy: 0.6964
Epoch 12/15
15/15 [==============================] - 9s 604ms/step - loss: 0.2524 - accuracy: 0.8921 - val_loss: 0.7111 - val_accuracy: 0.6752
Epoch 13/15
15/15 [==============================] - 9s 592ms/step - loss: 0.2143 - accuracy: 0.9081 - val_loss: 0.7246 - val_accuracy: 0.6953
Epoch 14/15
15/15 [==============================] - 9s 599ms/step - loss: 0.1829 - accuracy: 0.9284 - val_loss: 0.7323 - val_accuracy: 0.7221
Epoch 15/15
15/15 [==============================] - 9s 598ms/step - loss: 0.1469 - accuracy: 0.9460 - val_loss: 0.8435 - val_accuracy: 0.6998