编译的 JAX 函数无缘无故变慢

问题描述 投票:0回答:1

我正在使用 Jax 进行科学计算,具体来说,并行计算系统中所有粒子之间的成对相互作用力。这是为了模拟细菌种群的动态。我实现了一个编译函数,该函数返回相邻粒子对主粒子施加的力和扭矩,并使用 vmap 并行执行所有计算。我们旧模拟的主要问题是速度,我想通过力计算的并行化来运行得更快。

@jax.jit
def calc_interaction_forces_and_torques(p_main, q_main, p_neighb, q_neighb, activity_neighb,length, space_size,radius,w_a,k_cc):

neighbor_mapped_calc_int_forces = jax.vmap(calc_interaction_forces_and_torques, in_axes=(None, None, 0, 0, 0, None, None, None, None, None))

fully_vmapped_calc_int_forces = jax.vmap(neighbor_mapped_calc_int_forces, in_axes=(0, 0, None, None,None, None, None, None, None, None), out_axes=1)

我使用 jupyter 笔记本电脑对其进行计时。编译函数后,%%timeit 命令平均执行时间为 2.1 毫秒。

%%timeit

fully_vmapped_calc_int_forces(p, q, p, q, activity,length, space_size,radius,w_a,k_cc)

>>> 2.32 ms ± 180 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

但是,我有另一个函数可以计算单个前向时间步长。

@jax.jit
def one_timestep_loop(pos, theta, activity, space_size, length, radius, w_a, k_cc, F_prop, T_prop, activity_thresh, dt, gamma):
    pos = update_position_periodic(pos, space_size)
    p, q = endPoints(pos, theta, length)
        
    
    all_forces, all_torques, all_act_updates = fully_vmapped_calc_int_forces(p, q, p, q, activity, length,space_size, radius, w_a, k_cc)
    
    F_sum = jnp.sum(all_forces, axis=1)
    
    activity_term = activity[:, None] / length  # Precompute activity term without reshaping
    F_sum += F_prop * (p - pos) * activity_term  # Vectorized form

    T_sum = jnp.sum(all_torques, axis=1)
    T_sum += T_prop * activity  # Already vectorized
    
    local_density = jnp.sum(all_act_updates, axis=1)

    activity = activation(local_density, activity_thresh)

    v, w = update_velocity(p, q, F_sum, T_sum, gamma, length)
    
    pos, theta = update_position(pos, theta, v, w, dt, space_size)

    return pos, theta, activity

上述函数中的所有其他步骤对我来说似乎都很微不足道,这意味着它们没有理由对我的模拟造成任何瓶颈。然而,他们确实如此。


tic = time.time()
pos, theta, activity = one_timestep_loop(pos, theta, activity, space_size, length, radius, w_a, k_cc, F_prop, T_prop, activity_thresh, dt, gamma)
toc = time.time()
print(1000*(toc - tic))

当我第一次运行此代码块时,编译需要 800 毫秒。第二次需要 500 毫秒左右,之后,它会在 1 - 2 毫秒内执行。更重要的是,这是我不知道发生了什么的部分,%%timeit 命令返回平均 46.5 毫秒的执行时间

%%timeit
one_timestep_loop(pos, theta, activity, space_size, length, radius, w_a, k_cc, F_prop, T_prop, activity_thresh, dt, gamma)

>>> 46.5 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

当我尝试使用

lax.fori_loop
来模拟动态时,

def simulation_loop(t, loop_carry):
    # Unpack the carry variables
    pos, theta, activity = loop_carry

    # Call the function that wraps all updates per timestep
    pos, theta, activity = one_timestep_loop(
        pos, theta, activity, space_size, length, radius, w_a, k_cc, F_prop, T_prop, activity_thresh, dt, gamma
    )
    
    # Return updated variables to be passed to the next loop iteration
    return pos, theta, activity

# Main simulation function
def run_simulation(pos, theta, activity, num_steps):
    # Use lax.fori_loop for the loop execution
    pos, theta, activity = lax.fori_loop(
        1, num_steps + 1, simulation_loop, (pos, theta, activity)
    )
    return pos, theta, activity

num_steps = 100
tic = time.time()
pos, theta, activity = run_simulation(pos, theta, activity, num_steps)
toc = time.time()

print(f"Time it takes to step {num_steps}: ", (toc - tic)*1000, " ms")

此代码的执行有些需要 3000 毫秒,有些需要 2 毫秒,有时需要 400 毫秒。当我使用 python for 循环或 lax fori_loop 进行循环时,我的模拟仍然花费太多时间,尽管我已经实现了力计算的并行化(令人头痛的部分)并且它的执行速度非常快。

也许内存开销造成了一些麻烦,但我不知道。我非常绝望并寻求您的帮助,因为这导致我的研究速度减慢。

任何形式的帮助将不胜感激

谢谢!

我预计通过 jax 并行化加强力计算会提高模拟速度,但事实并非如此。函数的独立执行(全部用 jax.jit 编译)有时会给出意想不到的缓慢答案,并且循环时间变量以继续动态仍然非常慢

parallel-processing scientific-computing jax
1个回答
0
投票

运行 JAX 代码的微基准测试时,您应该记住JAX 常见问题解答:对 JAX 代码进行基准测试中提到的提示。

特别是,JAX 的 异步调度 可以解释您在这里看到的奇怪的时间安排。为了解决这个问题,请务必在您正在测量其执行时间的任何函数的输出上调用

block_until_ready
。上面链接的常见问题解答条目中有这样的示例。

© www.soinside.com 2019 - 2024. All rights reserved.