用于 JAX 的 numpy.ma.notmasked_edges() 的最简单等效实现

问题描述 投票:0回答:1

我有一个方形的 numpy.ndarray 和一个形状相同的 numpy 布尔掩码。我想找到数组每行中第一个未被屏蔽的元素。

我的代码目前依赖于 numpy.ma.notmasked_edges(),它正是我所需要的。但是,我现在需要将代码迁移到 JAX,JAX 尚未在 jax.numpy 中实现 numpy.ma。

仅调用 JAX 中已实现的 numpy 函数(不包括 numpy.ma)来查找每行中第一个未屏蔽元素的索引的最简单方法是什么?

我试图重现的代码类似于:

import numpy as np
my_array = np.random.rand(5,5)
mask = (my_array < 0.5)
my_masked_array = np.ma.masked_array(my_array, mask=mask)
np.ma.notmasked_edges(my_masked_array, axis=1)[0]

我确信有很多方法可以做到这一点,但我正在寻找最不笨重的方法。

numpy numpy-ndarray jax numpy-slicing
1个回答
0
投票

这是

nonmasked_edges
的 JAX 实现,它采用布尔掩码并返回与
numpy.ma
函数返回的相同索引:

import jax.numpy as jnp

def notmasked_edges(mask, axis=None):
  mask = jnp.asarray(mask)
  assert mask.dtype == bool
  if axis is None:
    mask = mask.ravel()
    axis = 0
  shape = list(mask.shape)
  del shape[axis]
  alltrue = mask.all(axis=axis).ravel()
  indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
  indices = [jnp.ravel(ind)[~alltrue] for ind in indices]

  first = indices.copy()
  first.insert(axis, jnp.argmin(mask, axis=axis).ravel()[~alltrue])

  last = indices.copy()
  last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel()[~alltrue])
  
  return [tuple(first), tuple(last)]

这与 JIT 不兼容,因为输出数组的大小取决于掩码的值(没有未掩码值的行将被排除)。

如果您想要 JIT 兼容版本,您可以删除

[~alltrue]
索引,对于没有未屏蔽值的行,将返回第一个/最后一个索引:

def notmasked_edges_v2(mask, axis=None):
  mask = jnp.asarray(mask)
  assert mask.dtype == bool
  if axis is None:
    mask = mask.ravel()
    axis = 0
  shape = list(mask.shape)
  del shape[axis]
  indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
  indices = [jnp.ravel(ind) for ind in indices]

  first = indices.copy()
  first.insert(axis, jnp.argmin(mask, axis=axis).ravel())

  last = indices.copy()
  last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel())

  return [tuple(first), tuple(last)]

这是一个例子:

import numpy as np
mask = np.array([[True, False, False, True],
                 [False, False, True, True],
                 [True, True, True, True]])

arr = np.ma.masked_array(np.ones_like(mask), mask=mask)
print(np.ma.notmasked_edges(arr, axis=1))
# [(array([0, 1]), array([1, 0])), (array([0, 1]), array([2, 1]))]

print(notmasked_edges(mask, axis=1))
# [(Array([0, 1], dtype=int32), Array([1, 0], dtype=int32)),
#  (Array([0, 1], dtype=int32), Array([2, 1], dtype=int32))]

print(notmasked_edges_v2(mask, axis=1))
# [(Array([0, 1, 2], dtype=int32), Array([1, 0, 0], dtype=int32)),
#  (Array([0, 1, 2], dtype=int32), Array([2, 1, 3], dtype=int32))]
© www.soinside.com 2019 - 2024. All rights reserved.