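I have the output of a conv layer, conv, with shape batch_size x H x W x C. I also have another tensor with shape batch_size x None x 2. The latter gives, for each example in the batch, a list of points (height and width coordinates), and the list length differs from example to example. I want to extract the channel vector at each of these points.
I tried tf.gather and tf.batch_gather, but neither seems suitable here.
Basically, for each batch element b I want to iterate over its points: each point has a height coordinate h_i and a width coordinate w_i, and for it I want conv[b, h_i, w_i, :]. Then I want to stack these results.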
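To pin down the intended semantics, here is a plain-loop reference written in NumPy rather than TensorFlow. It is a sketch that assumes conv is a NumPy array of shape (batch, H, W, C) and that the points arrive as a Python list of per-example (h, w) lists, so the lists may have different lengths; the name pick_points_loop is hypothetical. It returns one (num_points_b, C) array per example, because ragged results cannot be stacked into a single dense array without padding.

```python
import numpy as np

def pick_points_loop(conv, points_per_example):
    """Hypothetical reference helper (NumPy, not TensorFlow).

    For each batch element b and each point (h, w) of that element,
    take conv[b, h, w, :]. points_per_example[b] may have any length.
    """
    results = []
    for b, points in enumerate(points_per_example):
        # One row of C channel values per point of example b
        rows = [conv[b, h, w, :] for h, w in points]
        # Stack into shape (num_points_b, C); an example with no points
        # gets an empty (0, C) array instead of failing in np.stack
        if rows:
            results.append(np.stack(rows))
        else:
            results.append(np.empty((0, conv.shape[-1]), dtype=conv.dtype))
    return results

# Same values as the test images in the answer: batch=2, H=2, W=3, C=3
conv = np.arange(1, 37).reshape(2, 2, 3, 3)
# Ragged point lists: 2 points for example 0, 1 point for example 1
points = [[(0, 1), (1, 2)], [(1, 0)]]
out = pick_points_loop(conv, points)
print(out[0])  # rows [4 5 6] and [16 17 18]
print(out[1])  # row [28 29 30]
```

The loop is slow for large inputs but makes a handy ground truth to compare a vectorized TensorFlow version against.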
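Here is one way to do it: build a full [batch, height, width] index for every point and use tf.gather_nd. It assumes the coordinate tensor is dense, i.e. every example's point list has been padded to the same length num_coords: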
import tensorflow as tf

def pick_points(images, coords):
    coords = tf.convert_to_tensor(coords)
    s = tf.shape(coords)
    batch_size, num_coords = s[0], s[1]
    # Make batch indices: row b is [b, b, ..., b] (num_coords times)
    r = tf.range(batch_size, dtype=coords.dtype)
    idx_batch = tf.tile(tf.expand_dims(r, 1), [1, num_coords])
    # Full index: [batch_size, num_coords, 3] with entries (b, h, w)
    idx = tf.concat([tf.expand_dims(idx_batch, 2), coords], axis=2)
    # Gather pixels
    pixels = tf.gather_nd(images, idx)
    # Output has shape [batch_size, num_coords, num_channels]
    return pixels

# Test (TF 1.x API; under TF 2.x use tf.compat.v1.Session instead of tf.Session)
with tf.Graph().as_default(), tf.Session() as sess:
    # 2 x 2 x 3 x 3 (batch, H, W, C)
    images = [
        [
            [[ 1,  2,  3], [ 4,  5,  6], [ 7,  8,  9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
        ],
        [
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]],
            [[28, 29, 30], [31, 32, 33], [34, 35, 36]],
        ],
    ]
    # 2 x 2 x 2 (batch, num_coords, [h, w])
    coords = [
        [[0, 1], [1, 2]],
        [[1, 0], [1, 1]],
    ]
    pixels = pick_points(images, coords)
    print(sess.run(pixels))
    # [[[ 4  5  6]
    #   [16 17 18]]
    #
    #  [[28 29 30]
    #   [31 32 33]]]
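The answer's pick_points needs a dense coords tensor, so genuinely variable-length point lists must first be padded to a common length, and the values gathered at padded positions discarded afterwards. An alternative that avoids padding is to flatten all points, attach a batch index to each one, do a single (b, h, w) lookup, and split the result back per example. Below is a minimal NumPy sketch of that idea; pick_points_ragged is a hypothetical name, and porting it to TensorFlow (for instance with tf.gather_nd on an N x 3 index tensor) is an untested assumption.

```python
import numpy as np

def pick_points_ragged(conv, points_per_example):
    """Hypothetical helper (NumPy sketch, not TensorFlow).

    Gathers conv[b, h, w, :] for every point without padding and
    returns one (num_points_b, C) array per example.
    """
    # Number of points in each example (the ragged "None" dimension)
    lengths = [len(p) for p in points_per_example]
    # Batch index for every point, e.g. lengths [2, 1] -> [0, 0, 1]
    batch_idx = np.repeat(np.arange(len(lengths)), lengths)
    # All (h, w) pairs flattened into shape (total_points, 2)
    hw = np.concatenate(
        [np.asarray(p, dtype=np.int64).reshape(-1, 2) for p in points_per_example]
    )
    # One vectorized (b, h, w) lookup; result shape (total_points, C)
    flat = conv[batch_idx, hw[:, 0], hw[:, 1]]
    # Split back per example at the cumulative point counts
    return np.split(flat, np.cumsum(lengths)[:-1])

conv = np.arange(1, 37).reshape(2, 2, 3, 3)  # same values as the test images
parts = pick_points_ragged(conv, [[(0, 1), (1, 2)], [(1, 0)]])
print(parts[0])  # rows [4 5 6] and [16 17 18]
print(parts[1])  # row [28 29 30]
```

Flattening keeps the gather a single vectorized operation; the per-example split is only bookkeeping on the point counts.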