具有宽松扫描的 Jax vmap 在批量维度中具有不同的序列长度

问题描述 投票:0回答:1

我有以下代码,其中我的 sim_timestep 是批量的 我无法运行它,因为 lax.scan(fwd_dynamics, (xk,uk) ,jnp.arange(sim_timestep) ) 需要具体数组,但由于我已经映射了 state_predictor 函数,所以 sim_timestep 被作为一个 tracedArray 。 任何帮助将不胜感激 。 谢谢大家

from jax import random
from jax import lax
import jax
import jax.numpy as jnp
import pdb


def fwd_dynamics(x_u, xs):
    x0,uk =  x_u
    Delta_T = 0.001
    lwb = 1.2
    psi0=x0[2][0]
    v0= x0[3][0]
    vdot0 = uk[0][0]
    delta0 = uk[1][0]
    thetadot0 = uk[2][0]
        
    xdot= jnp.asarray([[v0*jnp.cos(psi0) ],
        [v0*jnp.sin(psi0)] ,
        [v0*jnp.tan(delta0)/(lwb)],
        [vdot0],
        [thetadot0]])
    x_next = x0 + xdot*Delta_T
    return (x_next,uk), x_next  # ("carryover", "accumulated")


def state_predictor( xk,uk ,sim_timestep):
    (x_next,_), _ = lax.scan(fwd_dynamics, (xk,uk) ,jnp.arange(sim_timestep) )
    return x_next

low = 0  # Adjust minimum value as needed
high = 100  # Adjust maximum value as needed
key = jax.random.PRNGKey(44)

sim_time = jax.random.randint(key, shape=(10, 1), minval=low, maxval=high)

xk = jax.random.uniform(key, shape=(10,5, 1))
uk = jax.random.uniform(key, shape=(10,2, 1))

state_predictor_vmap = jax.jit(jax.vmap(state_predictor,in_axes= 0 ,out_axes=0 ))
x_next = state_predictor_vmap( xk,uk ,sim_time)
print(x_next.shape)

我尝试通过上面的代码解决它,希望能找到替代方法来实现相同的功能。

python pytorch jax
1个回答
0
投票

您要求做的事情是不可能的:

scan
长度必须是静态的,并且根据定义,vmapped值是非静态的。

您可以做的是将

scan
替换为
fori_loop
while_loop
,然后循环边界不需要是静态的。例如,如果您以这种方式实现函数并保持其余代码不变,它应该可以工作:

def state_predictor(xk, uk, sim_timestep):
  body_fun = lambda i, x_u: fwd_dynamics(x_u, i)[0]
  x_next, _ = lax.fori_loop(0, sim_timestep[0], body_fun, (xk, uk))
  return x_next
© www.soinside.com 2019 - 2024. All rights reserved.