我有一个名为tensor
形状[batch_size, axis_1, axis_2]
的等级-3张量,并希望沿着第一轴将其分成batch_size
切片,如下所示:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
不幸的是,这不起作用,因为在构建图形期间尚未知道batch_size
的值。
我怎么解决这个问题?
我收到此错误:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
奇怪的是,尝试在其他TensorFlow函数中使用batch_size
似乎有效:
tensor = tf.reshape(tensor, [batch_size, -1])
尽管在图形构造过程中batch_size
的值是未知的,但工作正常。
tf.split()
问题特别严重吗?
解决方法是做:
batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
elems=tf.range(batch_size),
dtype=tf.float32)
我仍然对更好的解决方案感兴趣。