我有以下代码,其中我的 sim_timestep 是批量的 我无法运行它,因为 lax.scan(fwd_dynamics, (xk,uk) ,jnp.arange(sim_timestep) ) 需要具体数组,但由于我已经映射了 state_predictor 函数,所以 sim_timestep 被作为一个 tracedArray 。 任何帮助将不胜感激 。 谢谢大家
from jax import random
from jax import lax
import jax
import jax.numpy as jnp
import pdb
def fwd_dynamics(x_u, xs):
x0,uk = x_u
Delta_T = 0.001
lwb = 1.2
psi0=x0[2][0]
v0= x0[3][0]
vdot0 = uk[0][0]
delta0 = uk[1][0]
thetadot0 = uk[2][0]
xdot= jnp.asarray([[v0*jnp.cos(psi0) ],
[v0*jnp.sin(psi0)] ,
[v0*jnp.tan(delta0)/(lwb)],
[vdot0],
[thetadot0]])
x_next = x0 + xdot*Delta_T
return (x_next,uk), x_next # ("carryover", "accumulated")
def state_predictor( xk,uk ,sim_timestep):
(x_next,_), _ = lax.scan(fwd_dynamics, (xk,uk) ,jnp.arange(sim_timestep) )
return x_next
low = 0 # Adjust minimum value as needed
high = 100 # Adjust maximum value as needed
key = jax.random.PRNGKey(44)
sim_time = jax.random.randint(key, shape=(10, 1), minval=low, maxval=high)
xk = jax.random.uniform(key, shape=(10,5, 1))
uk = jax.random.uniform(key, shape=(10,2, 1))
state_predictor_vmap = jax.jit(jax.vmap(state_predictor,in_axes= 0 ,out_axes=0 ))
x_next = state_predictor_vmap( xk,uk ,sim_time)
print(x_next.shape)
我尝试通过上面的代码解决它,希望能找到替代方法来实现相同的功能。
您要求做的事情是不可能的:
scan
长度必须是静态的,并且根据定义,vmapped值是非静态的。
您可以做的是将
scan
替换为 fori_loop
或 while_loop
,然后循环边界不需要是静态的。例如,如果您以这种方式实现函数并保持其余代码不变,它应该可以工作:
def state_predictor(xk, uk, sim_timestep):
body_fun = lambda i, x_u: fwd_dynamics(x_u, i)[0]
x_next, _ = lax.fori_loop(0, sim_timestep[0], body_fun, (xk, uk))
return x_next