我有一个设置,我需要生成一些随机数,这些随机数由
vmap
消耗,然后由 lax.scan
消耗:
def generate_random(key: Array, upper_bound: int, lower_bound: int) -> int:
...
return num.astype(int)
def forward(key: Array, input: Array) -> Array:
k = generate_random(key, 1, 5)
computation = model(.., k, ..)
...
# Computing the forward pass
output = jax.vmap(forward, in_axes=.....
但是尝试将
num
从 jax.Array
转换为 int32
会导致 ConcretizationError
。
这可以通过这个最小示例重现:
@jax.jit
def t():
return jnp.zeros((1,)).item().astype(int)
o = t()
o
JIT 要求所有操作都是 Jax 类型。
但是
vmap
隐式使用JIT。出于性能原因,我更愿意保留它。
这是我的黑客尝试:
@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
key, subkey = jax.random.split(key)
random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
return random_number.astype(int)
def react_forward(key: Array, input: Array) -> Array:
k = get_rand_num(key, 1, MAX_ITERS)
# forward pass the model without tracking grads
intermediate_array = jax.lax.stop_gradient(model(input, k)) # THIS LINE ERRORS OUT
...
return ...
a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape
其中涉及创建
batch_size
#子密钥以在vmap
(a.shape[0]
)期间的每批使用,从而获得随机数。
但是它不起作用,因为
k
是从 jax.Array -> int
投射的。
但做出这些改变:
- k = get_rand_num(key, 1, MAX_ITERS) + k = 5 # any hardcoded int
工作完美。显然,采样导致了这里的问题......
您无法在
jit
、vmap
或任何其他 JAX 转换中将跟踪值转换为 Python 整数。您的最小示例的问题是对 .item()
的调用,它尝试将 JAX 标量转换为 Python 标量。
您可以通过避免此强制转换来解决此问题。这是函数的新版本,它返回零维整数数组,这就是 JAX 对整数标量进行编码的方式:
@jax.jit
def t():
return jnp.zeros((1,)).astype(int).reshape(())
也就是说,您如此关心从数组创建整数这一事实让我认为您的
model
函数要求其第二个参数是静态的,不幸的是,在这种情况下,上面的内容对您没有帮助,因为 不可能将 JAX 转换中的跟踪值转换为静态值。
如果您详细说明
# THIS LINE ERRORS OUT
(例如,错误是什么?您能提供该错误的最小可重现示例吗?),我们也许能够提供更多帮助,但如果没有,这只是猜测。