I built a U-Net in TensorFlow to segment images. Following the original paper, I implemented the same architecture except for the number of output channels (3 instead of 2). The model runs, but the loss does not converge; it just oscillates like an ECG trace...
I first used MSE as the loss function, but some sites say MSE is ineffective for U-Net, so I switched to a Dice loss. It still does not converge, so I now suspect the network structure itself is wrong.
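For reference, the Dice term I compute is the usual soft Dice taken over the whole batch, with a smoothing constant of 1 in the denominator (p are the network outputs, g the ground-truth masks); this matches the `train` block at the end of the code below:

    dice = 2 * Σ(p * g) / (Σp + Σg + 1)
    loss = 1 - dice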
import tensorflow as tf

tf.reset_default_graph()

with tf.name_scope('input'):
    X = tf.placeholder(tf.float32, shape=[None, 572, 572, 3], name='X')
    y = tf.placeholder(tf.float32, shape=[None, 388, 388, 3], name='y')

# Encoding (contracting path): two unpadded 3x3 convolutions per stage, then a 2x2 max pool
with tf.name_scope('layer1'):
    conv1 = tf.layers.conv2d(X, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv1')
    conv2 = tf.layers.conv2d(conv1, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv2')

with tf.name_scope('layer2'):
    pool1 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='pool1')
    conv3 = tf.layers.conv2d(pool1, filters=128, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv3')
    conv4 = tf.layers.conv2d(conv3, filters=128, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv4')

with tf.name_scope('layer3'):
    pool2 = tf.nn.max_pool(conv4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='pool2')
    conv5 = tf.layers.conv2d(pool2, filters=256, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv5')
    conv6 = tf.layers.conv2d(conv5, filters=256, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv6')

with tf.name_scope('layer4'):
    pool3 = tf.nn.max_pool(conv6, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='pool3')
    conv7 = tf.layers.conv2d(pool3, filters=512, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv7')
    conv8 = tf.layers.conv2d(conv7, filters=512, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv8')

with tf.name_scope('layer5'):
    pool4 = tf.nn.max_pool(conv8, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='pool4')
    conv9 = tf.layers.conv2d(pool4, filters=1024, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv9')
    conv10 = tf.layers.conv2d(conv9, filters=1024, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv10')
# Decoding (expansive path): 2x2 transposed convolution, crop-and-concat skip connection, two 3x3 convolutions
with tf.name_scope('layer6'):
    up_conv1 = tf.layers.conv2d_transpose(conv10, filters=512, kernel_size=2, strides=2)
    cropped_conv8 = tf.image.central_crop(conv8, 7 / 8)  # 64x64 -> 56x56
    concat1 = tf.concat([cropped_conv8, up_conv1], axis=-1)
    conv11 = tf.layers.conv2d(concat1, filters=512, kernel_size=3, activation=tf.nn.relu, name='conv11')
    conv12 = tf.layers.conv2d(conv11, filters=512, kernel_size=3, activation=tf.nn.relu, name='conv12')

with tf.name_scope('layer7'):
    up_conv2 = tf.layers.conv2d_transpose(conv12, filters=256, kernel_size=2, strides=2)
    cropped_conv6 = tf.image.central_crop(conv6, 13 / 17)  # 136x136 -> 104x104
    concat2 = tf.concat([cropped_conv6, up_conv2], axis=-1)
    conv13 = tf.layers.conv2d(concat2, filters=256, kernel_size=3, activation=tf.nn.relu, name='conv13')
    conv14 = tf.layers.conv2d(conv13, filters=256, kernel_size=3, activation=tf.nn.relu, name='conv14')

with tf.name_scope('layer8'):
    up_conv3 = tf.layers.conv2d_transpose(conv14, filters=128, kernel_size=2, strides=2)
    cropped_conv4 = tf.image.central_crop(conv4, 5 / 7)  # 280x280 -> 200x200
    concat3 = tf.concat([cropped_conv4, up_conv3], axis=-1)
    conv15 = tf.layers.conv2d(concat3, filters=128, kernel_size=3, activation=tf.nn.relu, name='conv15')
    conv16 = tf.layers.conv2d(conv15, filters=128, kernel_size=3, activation=tf.nn.relu, name='conv16')

with tf.name_scope('layer9'):
    up_conv4 = tf.layers.conv2d_transpose(conv16, filters=64, kernel_size=2, strides=2)
    cropped_conv2 = tf.image.central_crop(conv2, 49 / 71)  # 568x568 -> 392x392
    concat4 = tf.concat([cropped_conv2, up_conv4], axis=-1)
    conv17 = tf.layers.conv2d(concat4, filters=64, kernel_size=3, activation=tf.nn.relu, name='conv17')
    conv18 = tf.layers.conv2d(conv17, filters=64, kernel_size=3, activation=tf.nn.relu, name='conv18')

# 1x1 convolution down to the 3 output channels; no activation, so these are raw logits
output = tf.layers.conv2d(conv18, filters=3, kernel_size=1, name='output')
with tf.name_scope('train'):
    # Soft Dice over the whole batch; the +1 in the denominator avoids division by zero
    dice = 2 * tf.math.reduce_sum(output * y) / (tf.math.reduce_sum(output) + tf.math.reduce_sum(y) + 1)
    loss = 1 - dice
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope('save'):
    saver = tf.train.Saver()
    loss_summary = tf.summary.scalar('ls', loss)
    filewriter = tf.summary.FileWriter('./', tf.get_default_graph())
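For completeness, I drive the graph with an ordinary TF1 session loop along the lines of the sketch below; `next_batch` stands in for my data pipeline and the epoch count is arbitrary, so treat those names as placeholders rather than my exact training code:

n_epochs = 50  # placeholder
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        for X_batch, y_batch in next_batch():  # hypothetical generator yielding (image, mask) pairs
            _, loss_val, summary = sess.run([training_op, loss, loss_summary],
                                            feed_dict={X: X_batch, y: y_batch})
        filewriter.add_summary(summary, epoch)
        print(epoch, 'loss =', loss_val)
    saver.save(sess, './unet.ckpt')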
I know this is five years late, but I got stuck on the same problem. I also implemented the paper directly. In my case the dataset was ISBI-2012, where every image is 512x512. I did the following: