Tensorflow:根据来自另一个张量的信息从张量中选择一个非重叠切片列表

问题描述 投票:0回答:1

我有一个形状为conv_size x H x W x C的转换层conv的输出。我还有另一个具有Batch_size x None x 2形状的张量。后一个张量提供了一个点列表(高度和宽度坐标) bach中的每个示例(每个示例的列表长度不同)。我想为每个点提取Channel维度。

我尝试使用tf.gather和tf.batch_gather,但这两个似乎都不适合在这里使用。

基本上我想要的是每个批次b遍历点:对于每个点我有其h_i(高度坐标)和w_i(坐标)并返回conv [b,h_i,w_j,:]。然后叠加这些结果。

python tensorflow slice
1个回答
1
投票

以下是如何做到这一点:

import tensorflow as tf

def pick_points(images, coords):
    coords = tf.convert_to_tensor(coords)
    s = tf.shape(coords)
    batch_size, num_coords = s[0], s[1]
    # Make batch indices
    r = tf.range(batch_size, dtype=coords.dtype)
    idx_batch = tf.tile(tf.expand_dims(r, 1), [1, num_coords])
    # Full index
    idx = tf.concat([tf.expand_dims(idx_batch, 2), coords], axis=2)
    # Gather pixels
    pixels = tf.gather_nd(images, idx)
    # Output has shape [batch_size, num_coords, num_channels]
    return pixels

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    # 2 x 2 x 3 x 3
    images = [
        [
            [[ 1,  2,  3], [ 4,  5,  6], [ 7,  8,  9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
        ],
        [
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]],
            [[28, 29, 30], [31, 32, 33], [34, 35, 36]],
        ],
    ]
    # 2 x 2 x 2
    coords = [
        [[0, 1], [1, 2]],
        [[1, 0], [1, 1]],
    ]
    pixels = pick_points(images, coords)
    print(sess.run(pixels))
    # [[[ 4  5  6]
    #   [16 17 18]]
    #
    #  [[28 29 30]
    #   [31 32 33]]]
© www.soinside.com 2019 - 2024. All rights reserved.