我目前正在使用 U-Net 开发一种去噪算法,当我在我的数据上训练它时,MSE 并没有降低。
这是我正在处理的图像的形状:
x_train_noisy.shape: (288, 256, 256, 3)
x_train.shape: (288, 256, 256, 3)
x_val_noisy.shape: (32, 256, 256, 3)
x_val.shape: (32, 256, 256, 3)
这是我的神经网络:
def unet(input_shape, output_channels=3, output_activation='sigmoid'):
    """Build a U-Net for image-to-image regression (e.g. denoising).

    Args:
        input_shape: Shape of one input image, e.g. (256, 256, 3).
        output_channels: Channel count of the output image. Default 3
            preserves the original hard-coded RGB assumption.
        output_activation: Activation of the final conv layer. Default
            'sigmoid' preserves the original behavior and constrains
            outputs to [0, 1].
            NOTE(review): a sigmoid output only matches an MSE target if
            the clean images are scaled to [0, 1] — if the targets are
            e.g. raw 0–255 values the loss will plateau exactly as
            reported. Confirm the data scaling, or pass
            output_activation='linear'.

    Returns:
        An uncompiled Keras ``Model`` mapping noisy images to denoised
        images of the same spatial size.
    """
    def _double_conv(x, filters):
        # Two 3x3 same-padded ReLU convs, as used at every stage of the
        # original encoder/decoder.
        x = Conv2D(filters, 3, activation='relu', padding='same',
                   kernel_initializer='he_normal')(x)
        x = Conv2D(filters, 3, activation='relu', padding='same',
                   kernel_initializer='he_normal')(x)
        return x

    def _up_merge(x, skip, filters):
        # 2x nearest upsampling, 2x2 conv to halve channels, then concat
        # the matching encoder feature map and run the double conv.
        x = Conv2D(filters, 2, activation='relu', padding='same',
                   kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(x))
        x = concatenate([skip, x], axis=3)
        return _double_conv(x, filters)

    inputs = Input(shape=input_shape)

    # Encoder (contracting path): 64 -> 128 -> 256 -> 512 filters.
    conv1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    # NOTE(review): Dropout(0.5) is aggressive for a pixel-wise
    # regression task and can stall convergence — consider lowering or
    # removing it if training remains flat after fixing the data scale.
    drop4 = Dropout(0.5)(_double_conv(pool3, 512))
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck: 1024 filters.
    drop5 = Dropout(0.5)(_double_conv(pool4, 1024))

    # Decoder (expansive path) with skip connections from the encoder.
    conv6 = _up_merge(drop5, drop4, 512)
    conv7 = _up_merge(conv6, conv3, 256)
    conv8 = _up_merge(conv7, conv2, 128)
    conv9 = _up_merge(conv8, conv1, 64)

    # Output projection back to image space.
    outputs = Conv2D(output_channels, 3, activation=output_activation,
                     padding='same')(conv9)
    return Model(inputs=inputs, outputs=outputs)
当我训练神经网络时:
# Build the U-Net from the per-image shape (strip the batch dimension).
input_shape = x_train.shape[1:]
model = unet(input_shape)

# Compile for pixel-wise regression: Adam optimizer with MSE as both
# the loss and the reported metric.
learning_rate = 0.001
optimizer = Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['mse'])

# Fit noisy -> clean pairs, timing the whole run with wall-clock time.
t0 = time.time()
u_history = model.fit(
    x_train_noisy,
    x_train,
    batch_size=32,
    epochs=100,
    validation_data=(x_val_noisy, x_val),
)
t1 = time.time()

training_time_in_seconds = t1 - t0
print(f"U-Net training time: {training_time_in_seconds} seconds")
训练损失在第 1 个 epoch 后就卡在 0.0167,验证 MSE 从一开始就停在 0.0211,此后完全没有变化:
Epoch 1/100
9/9 [==============================] - 77s 2s/step - loss: 0.0334 - mse: 0.0334 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 2/100
9/9 [==============================] - 20s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 3/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 4/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 5/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 6/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 7/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 8/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 9/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 10/100
9/9 [==============================] - 21s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 11/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 12/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 13/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 14/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 15/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 16/100
9/9 [==============================] - 22s 3s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 17/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 18/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 19/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 20/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 21/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 22/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 23/100
9/9 [==============================] - 23s 3s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 24/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 25/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 26/100
9/9 [==============================] - 23s 3s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
Epoch 27/100
9/9 [==============================] - 22s 2s/step - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0211 - val_mse: 0.0211
我已将学习率降低至 0.0001,但这也没有帮助。