我有一个方形的 numpy.ndarray 和一个形状相同的 numpy 布尔掩码。我想找到数组每行中第一个未被屏蔽的元素。
我的代码目前依赖于 numpy.ma.notmasked_edges(),它正是我所需要的。但是,我现在需要将代码迁移到 JAX,JAX 尚未在 jax.numpy 中实现 numpy.ma。
仅调用 JAX 中已实现的 numpy 函数(不包括 numpy.ma)来查找每行中第一个未屏蔽元素的索引的最简单方法是什么?
我试图重现的代码类似于:
import numpy as np
my_array = np.random.rand(5,5)
mask = (my_array < 0.5)
my_masked_array = np.ma.masked_array(my_array, mask=mask)
np.ma.notmasked_edges(my_masked_array, axis=1)[0]
我确信有很多方法可以做到这一点,但我正在寻找最不笨重的方法。
这是
nonmasked_edges
的 JAX 实现,它采用布尔掩码并返回与 numpy.ma
函数返回的相同索引:
import jax.numpy as jnp
def notmasked_edges(mask, axis=None):
mask = jnp.asarray(mask)
assert mask.dtype == bool
if axis is None:
mask = mask.ravel()
axis = 0
shape = list(mask.shape)
del shape[axis]
alltrue = mask.all(axis=axis).ravel()
indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
indices = [jnp.ravel(ind)[~alltrue] for ind in indices]
first = indices.copy()
first.insert(axis, jnp.argmin(mask, axis=axis).ravel()[~alltrue])
last = indices.copy()
last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel()[~alltrue])
return [tuple(first), tuple(last)]
这与 JIT 不兼容,因为输出数组的大小取决于掩码的值(没有未掩码值的行将被排除)。
如果您想要 JIT 兼容版本,您可以删除
[~alltrue]
索引,对于没有未屏蔽值的行,将返回第一个/最后一个索引:
def notmasked_edges_v2(mask, axis=None):
mask = jnp.asarray(mask)
assert mask.dtype == bool
if axis is None:
mask = mask.ravel()
axis = 0
shape = list(mask.shape)
del shape[axis]
indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
indices = [jnp.ravel(ind) for ind in indices]
first = indices.copy()
first.insert(axis, jnp.argmin(mask, axis=axis).ravel())
last = indices.copy()
last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel())
return [tuple(first), tuple(last)]
这是一个例子:
import numpy as np
mask = np.array([[True, False, False, True],
[False, False, True, True],
[True, True, True, True]])
arr = np.ma.masked_array(np.ones_like(mask), mask=mask)
print(np.ma.notmasked_edges(arr, axis=1))
# [(array([0, 1]), array([1, 0])), (array([0, 1]), array([2, 1]))]
print(notmasked_edges(mask, axis=1))
# [(Array([0, 1], dtype=int32), Array([1, 0], dtype=int32)),
# (Array([0, 1], dtype=int32), Array([2, 1], dtype=int32))]
print(notmasked_edges_v2(mask, axis=1))
# [(Array([0, 1, 2], dtype=int32), Array([1, 0, 0], dtype=int32)),
# (Array([0, 1, 2], dtype=int32), Array([2, 1, 3], dtype=int32))]