Jax:在**JIT**下生成随机数

问题描述 投票:0回答:1

我有一个设置,我需要生成一些随机数,这些随机数由

vmap
消耗,然后由
lax.scan
消耗:

def generate_random(key: Array, upper_bound: int, lower_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....

但是尝试将

num
jax.Array
转换为
int32
会导致
ConcretizationError

这可以通过这个最小示例重现:

@jax.jit
def t():
  return jnp.zeros((1,)).item().astype(int)

o = t()
o

JIT 要求所有操作都是 Jax 类型。

但是

vmap
隐式使用JIT。出于性能原因,我更愿意保留它。


我的尝试

这是我的黑客尝试:

@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
  key, subkey = jax.random.split(key)
  random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
  return random_number.astype(int)

def react_forward(key: Array, input: Array) -> Array:
  k = get_rand_num(key, 1, MAX_ITERS)
  # forward pass the model without tracking grads
  intermediate_array = jax.lax.stop_gradient(model(input, k)) # THIS LINE ERRORS OUT
  ...
  return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape

其中涉及创建

batch_size
#子密钥以在
vmap
a.shape[0]
)期间的每批使用,从而获得随机数。

但是它不起作用,因为

k
是从
jax.Array -> int
投射的。

但做出这些改变:

-  k = get_rand_num(key, 1, MAX_ITERS)
+  k = 5 # any hardcoded int

工作完美。显然,采样导致了这里的问题......

python random equinox jax
1个回答
0
投票

您无法在

jit
vmap
或任何其他 JAX 转换中将跟踪值转换为 Python 整数。您的最小示例的问题是对
.item()
的调用,它尝试将 JAX 标量转换为 Python 标量。

您可以通过避免此强制转换来解决此问题。这是函数的新版本,它返回零维整数数组,这就是 JAX 对整数标量进行编码的方式:

@jax.jit
def t():
  return jnp.zeros((1,)).astype(int).reshape(())

也就是说,您如此关心从数组创建整数这一事实让我认为您的

model
函数要求其第二个参数是静态的,不幸的是,在这种情况下,上面的内容对您没有帮助,因为 不可能将 JAX 转换中的跟踪值转换为静态值

如果您详细说明

 # THIS LINE ERRORS OUT
(例如,错误是什么?您能提供该错误的最小可重现示例吗?),我们也许能够提供更多帮助,但如果没有,这只是猜测。

© www.soinside.com 2019 - 2024. All rights reserved.