When I started using vmap, I ran into a JAX tracer leak. I have two functions: grad_loss and its batched equivalent, grad_loss_batch.
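import equinox as eqx
import jax
import jax.numpy as jnp


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    # Single trajectory: the model is called with the time grid and the initial value.
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    # Batched version: share ti across the batch, map over the initial values.
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)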
Running some preliminary experiments:
ts = jnp.linspace(0.0, 1.0, 10)
ys = jax.lax.stop_gradient(jnp.ones_like(ts)[..., None])
...

# seem to work ok
with jax.checking_leaks():
    loss, grads = grad_loss(model_dde, ts, ys)

# Silently side-effecting, no error ?
loss2, grads2 = grad_loss_batch(model_dde, ts, ys[None, ...])

# Jax tracer leak
with jax.check_tracer_leaks():
    loss2, grads2 = grad_loss_batch(model_dde, ts, ys[None, ...])
This produces the following error:
(jdb) grad_loss_batch(model_dde, ts, ys[None, ...])
*** Exception: Leaked trace MainTrace(3,BatchTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<BatchTrace(level=3/0)> with
val = Array([[1.]], dtype=float32)
batch_dim = 0
This BatchTracer with object id 132573553507072 was created on line:
/home/monsel/Desktop/dev_diffrax/mwe.py:69 (grad_loss_batch)
<BatchTracer 132573553507072> is referred to by <list 132573501535936>[0]
<list 132573501535936> is referred to by <PjitParams 132573554102464>[0]
<PjitParams 132573554102464> is referred to by <InferParamsCacheEntry 132573553590720>
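I'm wondering if anyone can offer pointers on where to start debugging this. I know the problem depends on model and how it is defined. I would rather keep it as an abstraction here, since it lives in a fairly large codebase that readers would have to dig into. That said, I can provide a "large" MWE that reproduces the issue.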
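Tracer leaks usually occur when a function has side effects, for example when a traced value is stored in a cache and outlives the scope of the traced function. Here is a minimal example of such a tracer leak: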
import jax

class Model:
    def __call__(self, x):
        self.cache = x  # persistently storing a traced value -> tracer leak
        return x

jax.tree_util.register_dataclass(Model, data_fields=[], meta_fields=[])
model = Model()
x = jax.numpy.arange(5)
with jax.check_tracer_leaks():
    jax.vmap(model)(x)
Exception: Leaked trace MainTrace(1,BatchTrace). Leaked tracer(s):
Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
val = Array([0, 1, 2, 3, 4], dtype=int32)
batch_dim = 0
This BatchTracer with object id 133963186386912 was created on line:
<ipython-input-14-2a0f2d1c7a01>:12 (<cell line: 12>)
<BatchTracer 133963186386912> is referred to by <Model 133963190222512>.cache
<Model 133963190222512> is referred to by <dict 133965063269504>['model']
<dict 133965063269504> is referred to by <frame 133964059882208>
<frame 133964059882208> is referred to by <list 133963186878784>[0]
<list 133963186878784> is referred to by <FramesList 133963188017072>._frames
<FramesList 133963188017072> is referred to by <frame 95552919519184>
<frame 95552919519184> is referred to by <generator 133963188390256>
I suspect that whatever your model object is, it is probably doing something similar to this.
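If that turns out to be the case, one common way around it is to return the value you wanted to cache instead of assigning it to an attribute. The sketch below is only an illustration built on the toy example above: PureModel is a hypothetical name, and it makes no claim about how your actual model is structured.

```python
import jax
import jax.numpy as jnp


class PureModel:
    """Hypothetical side-effect-free variant of the toy Model above."""

    def __call__(self, x):
        cache = x  # the value the leaky version stored on self
        # Return it as an extra output instead of mutating the object,
        # so no tracer outlives the traced call.
        return x, cache


model = PureModel()
x = jnp.arange(5)

with jax.check_tracer_leaks():
    out, cache = jax.vmap(model)(x)
```

Because the intermediate value leaves the function as a regular output, vmap hands it back as a concrete array, and the caller can store it wherever it likes once tracing has finished.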