Tensorflow 2.0: 这相当于numpy.take_along_axis的意思。

问题描述 投票:0回答:1

这是我的问题:我已经实现了一个简单的函数,返回信号的峰值,组织成一个矩阵。

@tf.function
def get_peaks(X, X_err):
    prominence = 0.9
    # X shape (B, N, 1)
    max_pooled = tf.nn.pool(X, window_shape=(20, ), pooling_type='MAX', padding='SAME') 
    maxima = tf.equal(X, max_pooled) #shape (1, N, 1)
    maxima = tf.cast(maxima, tf.float32)
    peaks = tf.squeeze(X * maxima) #shape (1, N, 1) ==> shape (N,)
    peaks_err = X_err * tf.squeeze(maxima)
    peaks_idxs, idxs = tf.math.top_k(peaks, k=2)
    return peaks_idxs, idxs 

如你所见,输入的信号形状是 (B, N, 1)即批次样本,每个样本都是一个N个元素的单维向量。idxs 是正确的,也是 peaks_idxs它们的形状是(B, 2),即该批次中每个样品的两个最大值的位置(和峰值)。

问题是,我想也采取了 peak_err 与...相对应 idxs. 随着 numpy 我将使用。

np.take_along_axis(peaks_err, idxs, axis=1)

那实际上是返回一个正确的矩阵与形状 (B, 2). 我怎么能用tf做同样的事情呢?其实我已经尝试过使用 tf.gather:

tf.gather(peaks_err, idxs, axis=1)

但它是不工作的,结果是不正确的形状(B,B,2),和大量的零.你知道我怎么能解决?谢谢!我的问题是

python python-3.x slice tensorflow2.0
1个回答
0
投票

我已经解决了添加三行。

@tf.function
def get_local_maxima3(XC, SXC):
    prominence = 0.9
    # x shape (1, N, 1)
    max_pooled = tf.nn.pool(XC, window_shape=(20, ), pooling_type='MAX', padding='SAME') 
    maxima = tf.equal(XC, max_pooled) #shape (1, N, 1)
    maxima = tf.cast(maxima, tf.float32)
    peaks = tf.squeeze(XC * maxima) #shape (1, N, 1) ==> shape (N,)
    peaks_err = SXC * tf.squeeze(maxima)
    #maxima = tf.where(tf.greater(peaks, prominence)) # shape (N,)
    peaks, idxs = tf.math.top_k(peaks, k=2)

    idxs_shape = tf.shape(idxs)
    grid = tf.meshgrid(*(tf.range(idxs_shape[i]) for i in range(idxs.shape.ndims)), indexing='ij')
    index_full = tf.stack(grid[:-1] + [idxs], axis=-1)
    peaks_err = tf.gather_nd(peaks_err, index_full)
    return peaks, peaks_err

它的工作原理!如果你发现有一个更聪明更快速的解决方案,我会很感激它。

© www.soinside.com 2019 - 2024. All rights reserved.