我想在tensorflow中使用tf.where
函数。
selected_data = tf.where(mask,some_place_holder,zeros)
但是,当我写作时
zeros = tf.zeros(some_place_holder.shape)
发生错误:
ValueError: Cannot convert a partially known TensorShape to a Tensor: (?, 1000, 10)
我也尝试使用tf.fill
,但发生了类似的错误。
嗯,确实有一些解决方案,比如
zeros = tf.matmul(some_place_holder , tf.zeros([some_place_holder.shape[-1],some_place_holder.shape[-1]]))
但有更好的解决方案吗?
你可以使用tf.zeros_like(some_place_holder)
:
input_tensor = tf.placeholder(tf.int8, shape=[None, 3])
zeros = tf.zeros_like(input_tensor)
with tf.Session() as sess:
print(sess.run(zeros, feed_dict={input_tensor: [[1,2,3]]}))
# [[0 0 0]]