我正在训练一个面具r-cnn模型参考github上的这个代表:https://github.com/matterport/Mask_RCNN
我遇到了一个似乎是使用Keras问题的问题,所以我来到这里。
代码计算感兴趣区域(rois)和特征映射的掩码:
mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
config.IMAGE_SHAPE,
config.MASK_POOL_SIZE,
config.NUM_CLASSES)
但是,有时rois可能全为零,在这种情况下我想直接返回所有零。所以,我像这样使用tf.cond:
def ff_true():
mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
config.IMAGE_SHAPE,
config.MASK_POOL_SIZE,
config.NUM_CLASSES)
def ff_false():
return tf.zeros_like(target_mask)
mrcnn_mask = KL.Lambda(lambda x: tf.cond(tf.equal(tf.reduce_mean(x), 0),
ff_true, ff_true)) (rois)
这引发了一个错误:
ValueError:变量lambda_5 / cond / mrcnn_mask_conv1 / kernel /的初始化器来自控制流构造内部,例如循环或条件。在循环或条件内创建变量时,使用lambda作为初始化程序。
我谷歌它但没有有用的信息。这似乎是错误地使用keras / tensorflow的问题。任何线索都会受到欢迎!
顺便说一句,如果我使用这个代码,它将没有错误(但我不想提前计算):
a = build_fpn_mask_graph(rois, mrcnn_feature_maps,
config.IMAGE_SHAPE,
config.MASK_POOL_SIZE,
config.NUM_CLASSES)
def ff_true():
return a
def ff_false():
return tf.zeros_like(target_mask)
mrcnn_mask = KL.Lambda(lambda x: tf.cond(tf.equal(tf.reduce_mean(x), 0),
ff_true, ff_true)) (rois)
错误基本上就是消息所说的。您不能在条件中包含变量初始值设定项。与普通编程语言的粗略类比是:
if my_condition:
a = 1
print a # can't do this. a might be uninitialized.
下面是一个简单的示例来说明此问题以及错误消息中建议的修复:
import tensorflow as tf
def cond(i, _):
return i < 10
def body(i, _):
zero = tf.zeros([], dtype=tf.int32)
v = tf.Variable(initial_value=zero)
return (i + 1, v.read_value())
def body_ok(i, _):
zero = lambda: tf.zeros([], dtype=tf.int32)
v = tf.Variable(initial_value=zero, dtype=tf.int32)
return (i + 1, v.read_value())
tf.while_loop(cond, body, [0, 0])
这是使用tf.while_loop
但是为了这个目的它与tf.cond
相同。如果按原样运行此代码,您将收到相同的错误。如果用body
替换body_ok
,一切都会好的。原因是当初始化程序是一个函数时,张量流可以将其置于“控制流上下文之外”以确保它始终运行。
为了澄清未来读者可能存在的混淆,“首先计算a
”的方法并不理想,而是出于微妙的原因。首先,请记住,您在此处所做的是构建计算图(假设您没有使用eager execution)。所以,你实际上并没有计算a
。您只是定义如何计算它。 Tensorflow运行时决定在运行时需要计算的内容,具体取决于session.run()
的参数。因此,可以预期如果条件为假,则不会执行返回a
的分支(因为它不需要)。不幸的是,这不是TensorFlow运行时的工作原理。您可以在第一个答案here中找到更多详细信息,但简单地说,TensorFlow运行时将执行任一分支的所有依赖关系,只有true_fn/false_fn
内的操作才会有条件地执行。
使用带有CNN-LSTM的keras时,我也会遇到同样的问题。代码在GPU服务器上工作正常,但是当我尝试在我的本地机器上运行时,得到了这个奇怪的错误。
以下技巧对我有用。
解决方案:清除变量并重新启动内核。这对我有用。也许其他人遇到完全相同的问题,我将会有所帮助。