Trained U-Net model predicts only black

Question (votes: 0, answers: 2)

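I am training a U-Net model for segmentation. After training, the output is all zeros, and I don't understand why.

I saw advice to use a loss function suited to this problem, so I used the Dice loss. The reason is that the black region (0) is much larger than the white region (1).

Am I doing something wrong?

My model is: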

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 80, 80, 1)    0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 80, 80, 64)   640         input_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 80, 80, 64)   36928       conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 40, 40, 64)   0           conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 40, 40, 128)  73856       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 40, 40, 128)  147584      conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 20, 20, 128)  0           conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 20, 20, 256)  295168      max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 20, 20, 256)  590080      conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 10, 10, 256)  0           conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 10, 10, 512)  1180160     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 10, 10, 512)  2359808     conv2d_7[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 10, 10, 512)  0           conv2d_8[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 5, 5, 512)    0           dropout_1[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 5, 5, 1024)   4719616     max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 5, 5, 1024)   9438208     conv2d_9[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 5, 5, 1024)   0           conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 10, 10, 512)  2097664     dropout_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 10, 10, 1024) 0           dropout_1[0][0]
                                                                 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 10, 10, 512)  4719104     concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 10, 10, 512)  2359808     conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 20, 20, 256)  524544      conv2d_12[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 20, 20, 512)  0           conv2d_6[0][0]
                                                                 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 20, 20, 256)  1179904     concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 20, 20, 256)  590080      conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 40, 40, 128)  131200      conv2d_14[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 40, 40, 256)  0           conv2d_4[0][0]
                                                                 conv2d_transpose_3[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 40, 40, 128)  295040      concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 40, 40, 128)  147584      conv2d_15[0][0]
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 80, 80, 64)   32832       conv2d_16[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 80, 80, 128)  0           conv2d_2[0][0]
                                                                 conv2d_transpose_4[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 80, 80, 64)   73792       concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 80, 80, 64)   36928       conv2d_17[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 80, 80, 2)    1154        conv2d_18[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 80, 80, 1)    3           conv2d_19[0][0]
==================================================================================================

Loss function

def dice_loss_v2(y_true, y_pred):
    numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=(1,2,3))
    denominator = tf.reduce_sum(y_true + y_pred, axis=(1,2,3))

    return 1 - numerator / denominator
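
One thing that may be worth checking, offered as an assumption rather than a confirmed cause: with the formula above, a sample whose mask is empty and whose prediction is all zeros evaluates to 0/0. A common variant adds a small smoothing constant to the numerator and denominator. The NumPy sketch below mirrors dice_loss_v2 so it can be tried without a model; the names dice_loss_np and smooth are illustrative and do not come from the question.

```python
import numpy as np

def dice_loss_np(y_true, y_pred, smooth=1.0):
    """NumPy mirror of dice_loss_v2 with an added smoothing term.

    The smoothing term is an assumption, not part of the question's code.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # Sum over height, width and channel; keep one value per batch item.
    numerator = 2.0 * np.sum(y_true * y_pred, axis=(1, 2, 3))
    denominator = np.sum(y_true + y_pred, axis=(1, 2, 3))
    return 1.0 - (numerator + smooth) / (denominator + smooth)

# Hypothetical 80x80 single-channel batches matching the model's output shape.
empty = np.zeros((1, 80, 80, 1))
mask = np.zeros((1, 80, 80, 1))
mask[0, 30:50, 30:50, 0] = 1.0  # a 20x20 white square

loss_empty = dice_loss_np(empty, empty)    # empty mask, all-zero prediction
loss_missed = dice_loss_np(mask, empty)    # object present, all-zero prediction
loss_perfect = dice_loss_np(mask, mask)    # prediction equals the mask
```

With the smoothing term the empty case is well defined, while an all-zero prediction on a non-empty mask still receives a loss close to 1.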

Compilation

    model.compile(optimizer='adam',
                  loss=dice_loss_v2,
                  metrics=['accuracy', iou_loss_core])

The learning rate is left at the predefined value, LR = 0.001.

Additional information:

        datagen = ImageDataGenerator(
                    rotation_range=10, 
                    width_shift_range=0.1, 
                    height_shift_range=0.1, 
                    zoom_range=0.1)
        datagen.fit(X_train)

        model.fit_generator(datagen.flow(X_train, y_train, batch_size=100), steps_per_epoch=len(X_train), 
                                        epochs=4, validation_data=(X_test, y_test))
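
A side note on the call above: in Keras, steps_per_epoch counts batches, not samples, so with batch_size=100 a value of len(X_train) makes each epoch draw roughly 100 times the training set. A minimal sketch of the usual arithmetic follows; n_samples is a hypothetical stand-in for len(X_train) and steps_for_one_pass is an illustrative helper, not a Keras function.

```python
import math

def steps_for_one_pass(n_samples, batch_size):
    """Number of batches needed to see every sample once."""
    return math.ceil(n_samples / batch_size)

n_samples = 2500   # hypothetical training-set size
batch_size = 100   # value used in the question
steps = steps_for_one_pass(n_samples, batch_size)
```

Separately, datagen.flow(X_train, y_train) applies the random transformations to the images only and passes y_train through unchanged, so geometric augmentation may leave the masks misaligned with their images.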
python python-3.x keras conv-neural-network training-data
2 Answers
0 votes

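Because the pixels are predicted as 0 and 1, the mask may look entirely black: in the 0-255 color range, both 0 and 1 are close to black. So multiply the mask by 255 before saving or displaying it (every 1 becomes 255). You should then get the expected output.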
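
The scaling step can be sketched in NumPy as follows. The 0.5 threshold assumes the network emits values in [0, 1] (for example from a sigmoid), which the question does not state; mask_to_uint8 and the sample array are illustrative names only.

```python
import numpy as np

def mask_to_uint8(pred, threshold=0.5):
    """Binarize predictions in [0, 1] and scale them to 0/255 for display."""
    pred = np.asarray(pred, dtype=np.float32)
    binary = pred > threshold               # boolean mask
    return binary.astype(np.uint8) * 255    # True -> 255, False -> 0

# Hypothetical 2x2 prediction; a real one would have shape (80, 80).
pred = np.array([[0.1, 0.9],
                 [0.6, 0.4]])
img = mask_to_uint8(pred)
```

The resulting uint8 array can be handed to any image writer or viewer, where 255 renders as white.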


0 votes

Pay attention to the model's output.

In my case, the output went through nn.tanh(), so it lay between -1 and 1. To display the image, I needed (output + 1) / 2.
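
A minimal NumPy sketch of that check and rescaling, assuming the raw output has already been converted to an array; tanh_to_unit_range is an illustrative helper name, not part of any library.

```python
import numpy as np

def tanh_to_unit_range(output):
    """Map values from [-1, 1] (tanh range) to [0, 1] for display."""
    output = np.asarray(output, dtype=np.float32)
    # Clip guards against small numerical overshoot outside [-1, 1].
    return np.clip((output + 1.0) / 2.0, 0.0, 1.0)

# Hypothetical raw outputs covering the tanh range.
raw = np.array([-1.0, 0.0, 1.0])

# Inspect the value range first; it shows which rescaling is needed.
print("raw range:", raw.min(), raw.max())

scaled = tanh_to_unit_range(raw)
```

Printing the minimum and maximum of the raw prediction is a quick way to tell whether the image is really empty or only displayed on the wrong scale.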
