向损失函数添加整数参数后,JAX/Equinox 管道速度变慢

问题描述 投票:0回答:1

我有一个 JAX 和 Equinox 的训练流程。我想将批量索引传递给损失函数,以便根据索引应用不同的逻辑。如果没有批量索引训练循环,则工作时间约为 15 秒,但如果我通过索引,则它会减慢大约一个小时。你能解释一下,为什么会发生这种情况吗?我是 JAX 新手,抱歉。

def fit_cv(model: eqx.Module, 
           dataloader: jdl.DataLoader, 
           optimizer: optax.GradientTransformation, 
           loss: tp.Callable, 
           n_steps: int = 1000):
    
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    dloss = eqx.filter_jit(eqx.filter_value_and_grad(loss))
    
    @eqx.filter_jit
    def step(model, data, opt_state, batch_index):
        loss_score, grads = dloss(model, data, batch_index)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_score
    
    loss_history = []
    for batch_index, batch in tqdm(zip(range(n_steps), dataloader), total=n_steps):
        if batch_index >= n_steps:
            break
        batch = batch[0] # dataloader returns tuple of size (1,)
        model, opt_state, loss_score = step(model, batch, opt_state, batch_index)
        loss_history.append(loss_score)
    return model, loss_history

损失函数具有以下签名

def loss(self, model: eqx.Module, data: jnp.ndarray, batch_index: int):
jax equinox
1个回答
0
投票

我怀疑问题是过度重新编译。您正在使用

filter_jit
,根据 docs 具有以下属性:

所有 JAX 和 NumPy 数组都会被跟踪,所有其他类型都保持静态。

每次 JIT 编译函数的静态参数发生变化时,都会触发重新编译。这意味着每次您使用新值

batch_index
调用函数时,该函数都会被重新编译。

作为修复,我建议使用常规的旧

jax.jit
,它要求您显式指定静态参数(像这样的潜在惊喜是 JAX 做出此设计选择的原因之一!)。只要您不将
batch_index
标记为静态,您就不应该看到这种重新编译的惩罚。

或者,如果您想继续使用

filter_jit
,那么您可以将步骤调用更改为:

step(model, batch, opt_state, jnp.asarray(batch_index))

通过此更改,

filter_jit
将不再决定使批次索引静态。当然,这些建议中的任何一个都要求
loss
与动态
batch_index
兼容,这无法从您的问题中包含的信息确定。

© www.soinside.com 2019 - 2024. All rights reserved.