我正在寻找一种在张量流中有条件打印节点的方法,使用下面的代码示例行,其中每10个循环计数,它应该在控制台中打印一些内容。但这对我不起作用。有人可以建议吗?
谢谢,Hamidreza,
epsilon = tf.cond(tf.constant(counter % 10 == 0, dtype=tf.bool), true_fn=lambda:tf.Print(epsilon, [counter, epsilon], 'batch: ', summarize=10), false_fn=lambda:epsilon)
我遇到了类似的问题,并通过以下难看的解决方案解决了它。
由于解密,我使用的是tf.print
而不是'tf.Print':
由于true_func和false_func应该返回相同的类型和形状,所以我返回了无意义的const。
def true_func():
printfunc = tf.print(TENSOR_TO_PRINT,summarize=-1)
with tf.control_dependencies([printfunc]):
return tf.constant(1)
def false_func():
return tf.constant(1)
shouldprint = tf.cond(YOUR_CONDITION ,true_fn = true_func ,false_fn=false_func)
然后您可以运行shouldprint
操作或将其添加为其他操作的依赖项。