在jax中实现if-then-elif-then-else

问题描述 投票:0回答:2

我刚刚开始使用 JAX,我想知道在 JAX/Python 中实现 if-then-elif-then-else 的正确方法是什么?例如,给定输入数组:

n = [5, 4, 3, 2]
k = [3, 3, 3, 3]
,我需要实现以下伪代码:

def n_choose_k_safe(n, k):
    r = jnp.empty(4)
    for i in range(4):
        if n[i] < k[i]:
            r[i] = 0
        elif n[i] == k[i]:
            r[i] = 1
        else:
            r[i] = func_nchoosek(n[i], k[i])
    return r

有很多选择,如

vmap
lax.select
lax.where
jax.cond
lax.fori_loop
等,因此很难决定要使用的实用程序的具体组合。顺便说一句,
k
可以是标量(如果这样更简单的话)。

python jax cudnn
2个回答
4
投票

Valentin 的答案中有一种稍微更紧凑的方式来表达解决方案,使用

jax.numpy.select
:

def n_choose_k_safe(n, k):
  return jnp.select(condlist=[n > k, n == k],
                    choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
                    default=0)

对于长度为 4 的输入数组,假设

func_nchoosek
jax.vmap
兼容,这应该返回与原始代码相同的结果。此处使用
vectorize
代替
vmap
将使该函数也与
k
的标量输入兼容,而无需手动设置
in_axes
参数。


3
投票

首先,您可以对函数

func_nchoosek
进行向量化,以接受
n
k
作为平面向量(我们假设
func_nchoosek
接受形状为 (1, ) 的输入,否则应该首先这样做),然后:

func_nchoosek_vect = jax.vmap(func_nchoosek, (0, 0), 0)

现在

func_nchoosek_vect([n1, n2, ...], [k1, k2, ...]) = [func_nchoosek(n1, k1), func_nchoosek(n2, k2), ...]
操作是逐元素完成的(类似于
zip
)。

如果

k
是单个标量,您可以使用它:

func_nchoosek_vect = jax.vmap(func_nchoosek, (0, None), 0)

现在您可以使用功能

jnp.where
来选择您想要的数据。它类似于
lax.select
但更灵活。该函数与 jit 编译兼容,并保留梯度(在一些进一步的假设下),只要您将它与 3 个参数一起使用(以获得确定性形状)。

def n_choose_k_safe(n: jnp.array, k: jnp.array) -> jnp.array:
  """Choose k among n with safety."""
  r = jnp.where(n > k, func_nchoosek_vect(n, k), -1)
  r = jnp.where(n == k, 1, r)
  r = jnp.where(n < k, 0, r)
  return r
© www.soinside.com 2019 - 2024. All rights reserved.