我刚刚开始使用 JAX,我想知道在 JAX/Python 中实现 if-then-elif-then-else 的正确方法是什么?例如,给定输入数组:
n = [5, 4, 3, 2]
和k = [3, 3, 3, 3]
,我需要实现以下伪代码:
def n_choose_k_safe(n, k):
r = jnp.empty(4)
for i in range(4):
if n[i] < k[i]:
r[i] = 0
elif n[i] == k[i]:
r[i] = 1
else:
r[i] = func_nchoosek(n[i], k[i])
return r
有很多选择,如
vmap
、lax.select
、lax.where
、jax.cond
、lax.fori_loop
等,因此很难决定要使用的实用程序的具体组合。顺便说一句,k
可以是标量(如果这样更简单的话)。
Valentin 的答案中有一种稍微更紧凑的方式来表达解决方案,使用
jax.numpy.select
:
def n_choose_k_safe(n, k):
return jnp.select(condlist=[n > k, n == k],
choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
default=0)
对于长度为 4 的输入数组,假设
func_nchoosek
与 jax.vmap
兼容,这应该返回与原始代码相同的结果。此处使用 vectorize
代替 vmap
将使该函数也与 k
的标量输入兼容,而无需手动设置 in_axes
参数。
首先,您可以对函数
func_nchoosek
进行向量化,以接受 n
和 k
作为平面向量(我们假设 func_nchoosek
接受形状为 (1, ) 的输入,否则应该首先这样做),然后:
func_nchoosek_vect = jax.vmap(func_nchoosek, (0, 0), 0)
现在
func_nchoosek_vect([n1, n2, ...], [k1, k2, ...]) = [func_nchoosek(n1, k1), func_nchoosek(n2, k2), ...]
操作是逐元素完成的(类似于zip
)。
如果
k
是单个标量,您可以使用它:
func_nchoosek_vect = jax.vmap(func_nchoosek, (0, None), 0)
现在您可以使用功能
jnp.where
来选择您想要的数据。它类似于 lax.select
但更灵活。该函数与 jit 编译兼容,并保留梯度(在一些进一步的假设下),只要您将它与 3 个参数一起使用(以获得确定性形状)。
def n_choose_k_safe(n: jnp.array, k: jnp.array) -> jnp.array:
"""Choose k among n with safety."""
r = jnp.where(n > k, func_nchoosek_vect(n, k), -1)
r = jnp.where(n == k, 1, r)
r = jnp.where(n < k, 0, r)
return r