我正在构建一个拓扑结构的计算图,该图根据一些超参数而变化。在某些时候,会发生串联:
c = tf.concat([a, b], axis=-1)
张量a
有形状(None, m)
。张量b
形状(None, n)
n
依赖于超参数。对于超参数的一个值,张量b
在概念上应该是空的,例如,我们希望c
和a
是一样的。
我可以使用以下代码成功构建图形:
b = tf.placeholder(tf.float32, (None, 0), name="Empty")
但是,如果我运行一个会话,TensorFlow会引发一个InvalidArgumentError
说:
You must feed a value for placeholder tensor 'Empty' with dtype float and shape [?,0]
有没有办法构造一个在concat
操作中表现为空的张量,但不需要输入虚假输入?
显然,我知道我可以在构建图形的代码中添加一个特殊情况,包装器等。我希望避免这种情况。
完整代码:
import tensorflow as tf
import numpy as np
a = tf.placeholder(tf.float32, (None, 10))
b = tf.placeholder(tf.float32, (None, 0), name="Empty")
c = tf.concat([a, b], axis=-1)
assert c.shape.as_list() == [None, 10]
with tf.Session() as sess:
a_feed = np.zeros((100, 10))
c = sess.run(c, {a : a_feed})
您可以使用不需要占位符的tf.placeholder_with_default。
import tensorflow as tf
import numpy as np
# Hparams
batch_size = 100
a_dim = 10
b_dim = 0
# Placeholder for a which is required to be fed.
a = tf.placeholder(tf.float32, (None, a_dim))
# Placeholder for b, which doesn't have to be fed.
b_default = np.zeros((batch_size, b_dim), dtype=np.float32)
b = tf.placeholder_with_default(
b_default, (None, b_dim), name="Empty"
)
c = tf.concat([a, b], axis=-1)
assert c.shape.as_list() == [None, a_dim + b_dim]
with tf.Session() as sess:
a_feed = np.zeros((batch_size, a_dim))
b_feed = np.ones((batch_size, b_dim))
c_out = sess.run(c, {a : a_feed})
# You can optionally feed in b:
# c_out = sess.run(c, {a : a_feed, b : b_feed})
print(c_out)
如果您不使用tf.placeholder()
来提供数据,但tf.Estimator
,那么解决方案是微不足道的,因为您可以定义:
b = tf.zeros([a.shape[0].value, 0])
所以,如果a的形状是已知的,
c = tf.concat([a,b],axis=-1)
assert c.shape == a.shape
永远都会成功。