我有一个 JAX 和 Equinox 的训练流程。我想将批量索引传递给损失函数,以便根据索引应用不同的逻辑。如果没有批量索引训练循环,则工作时间约为 15 秒,但如果我通过索引,则它会减慢大约一个小时。你能解释一下,为什么会发生这种情况吗?我是 JAX 新手,抱歉。
def fit_cv(model: eqx.Module,
dataloader: jdl.DataLoader,
optimizer: optax.GradientTransformation,
loss: tp.Callable,
n_steps: int = 1000):
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
dloss = eqx.filter_jit(eqx.filter_value_and_grad(loss))
@eqx.filter_jit
def step(model, data, opt_state, batch_index):
loss_score, grads = dloss(model, data, batch_index)
updates, opt_state = optimizer.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss_score
loss_history = []
for batch_index, batch in tqdm(zip(range(n_steps), dataloader), total=n_steps):
if batch_index >= n_steps:
break
batch = batch[0] # dataloader returns tuple of size (1,)
model, opt_state, loss_score = step(model, batch, opt_state, batch_index)
loss_history.append(loss_score)
return model, loss_history
损失函数具有以下签名
def loss(self, model: eqx.Module, data: jnp.ndarray, batch_index: int):
我怀疑问题是过度重新编译。您正在使用
filter_jit
,根据 docs 具有以下属性:
所有 JAX 和 NumPy 数组都会被跟踪,所有其他类型都保持静态。
每次 JIT 编译函数的静态参数发生变化时,都会触发重新编译。这意味着每次您使用新值
batch_index
调用函数时,该函数都会被重新编译。
作为修复,我建议使用常规的旧
jax.jit
,它要求您显式指定静态参数(像这样的潜在惊喜是 JAX 做出此设计选择的原因之一!)。只要您不将 batch_index
标记为静态,您就不应该看到这种重新编译的惩罚。
或者,如果您想继续使用
filter_jit
,那么您可以将步骤调用更改为:
step(model, batch, opt_state, jnp.asarray(batch_index))
通过此更改,
filter_jit
将不再决定使批次索引静态。当然,这些建议中的任何一个都要求 loss
与动态 batch_index
兼容,这无法从您的问题中包含的信息确定。