JAX tracer leak


I started running into a JAX tracer leak when I began using vmap. I have two functions, grad_loss and its batched equivalent grad_loss_batch:

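import equinox as eqx
import jax
import jax.numpy as jnp

# Unbatched loss: model(ti, y0) with y0 = first sample of the trajectory
@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)

# Batched loss: ti is shared across the batch (in_axes None),
# the initial samples yi[:, 0] are mapped over axis 0
@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)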
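
To make the shapes in grad_loss_batch concrete, here is a minimal sketch of what jax.vmap(model, (None, 0)) does. The real model_dde is not shown, so toy_model below is a hypothetical pure stand-in that only assumes the same call signature model(ti, y0) used in grad_loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model_dde (NOT the real model): maps a time grid
# ti of shape (T,) and an initial state y0 of shape (1,) to a trajectory (T, 1).
def toy_model(ti, y0):
    return y0[None, :] * jnp.exp(-ti)[:, None]

ts = jnp.linspace(0.0, 1.0, 10)
ys_batch = jnp.ones_like(ts)[None, :, None]   # (batch=1, T=10, 1)

# in_axes=(None, 0): ti is shared by every batch element, while the
# initial states ys_batch[:, 0] (shape (1, 1)) are mapped over axis 0.
y_pred = jax.vmap(toy_model, in_axes=(None, 0))(ts, ys_batch[:, 0])
print(y_pred.shape)  # (1, 10, 1), same shape as ys_batch
```

Because the output has the same shape as ys_batch, the (yi - y_pred) ** 2 term in grad_loss_batch needs no extra broadcasting.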

Running some preliminary experiments:

ts = jnp.linspace(0.0, 1.0, 10)
ys = jax.lax.stop_gradient(jnp.ones_like(ts)[..., None])
...

# seems to work OK
with jax.checking_leaks():
    loss, grads = grad_loss(model_dde, ts, ys)

# silently side-effecting, no error?
loss2, grads2 = grad_loss_batch(model_dde, ts, ys[None, ...])

# Jax tracer leak
with jax.check_tracer_leaks():
    loss2, grads2 = grad_loss_batch(model_dde, ts, ys[None, ...])

This produces the following error:

(jdb) grad_loss_batch(model_dde, ts, ys[None, ...])
*** Exception: Leaked trace MainTrace(3,BatchTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<BatchTrace(level=3/0)> with
  val = Array([[1.]], dtype=float32)
  batch_dim = 0
This BatchTracer with object id 132573553507072 was created on line:
  /home/monsel/Desktop/dev_diffrax/mwe.py:69 (grad_loss_batch)
<BatchTracer 132573553507072> is referred to by <list 132573501535936>[0]
<list 132573501535936> is referred to by <PjitParams 132573554102464>[0]
<PjitParams 132573554102464> is referred to by <InferParamsCacheEntry 132573553590720>

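I was wondering if anyone could give me some pointers on how to start debugging this. I know the problem depends on model and how it is defined. I'd rather keep it abstract, since it's a fairly large codebase that people would have to dig into. Still, I can provide a "large" MWE that reproduces this.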
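
One way to start narrowing this down, sketched under assumptions: run the same batched-loss structure with a pure stand-in for the model and check whether jax.check_tracer_leaks() still complains. toy_model and its params dict are hypothetical, and plain jax.value_and_grad replaces eqx.filter_value_and_grad so the sketch needs only JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure stand-in for model_dde: a plain function of its
# parameters that writes no attributes, caches or globals during the call.
def toy_model(params, ti, y0):
    return y0[None, :] * jnp.exp(-params["rate"] * ti)[:, None]

@jax.value_and_grad
def grad_loss_batch(params, ti, yi):
    # Same structure as the question's batched loss, minus Equinox filtering.
    y_pred = jax.vmap(toy_model, in_axes=(None, None, 0))(params, ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)

params = {"rate": jnp.array(0.5)}
ts = jnp.linspace(0.0, 1.0, 10)
ys = jnp.ones_like(ts)[..., None]

# If this runs cleanly, the vmap/grad plumbing itself is fine and the leak
# most likely comes from something the real model does internally.
with jax.check_tracer_leaks():
    loss, grads = grad_loss_batch(params, ts, ys[None, ...])
```

If the stand-in passes, swapping pieces of the real model back in one at a time is a straightforward way to bisect which component introduces the leak.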

Answer

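Tracer leaks usually happen when a function has side effects, for example when a traced value is stored in a cache that persists beyond the scope of the traced function. Here is a minimal example of such a tracer leak: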

import jax

class Model:
  def __call__(self, x):
    self.cache = x  # persistently storing a traced value -> tracer leak
    return x
jax.tree_util.register_dataclass(Model, data_fields=[], meta_fields=[])

model = Model()
x = jax.numpy.arange(5)

with jax.check_tracer_leaks():
  jax.vmap(model)(x)
This raises:

Exception: Leaked trace MainTrace(1,BatchTrace). Leaked tracer(s):

Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1, 2, 3, 4], dtype=int32)
  batch_dim = 0
This BatchTracer with object id 133963186386912 was created on line:
  <ipython-input-14-2a0f2d1c7a01>:12 (<cell line: 12>)
<BatchTracer 133963186386912> is referred to by <Model 133963190222512>.cache
<Model 133963190222512> is referred to by <dict 133965063269504>['model']
<dict 133965063269504> is referred to by <frame 133964059882208>
<frame 133964059882208> is referred to by <list 133963186878784>[0]
<list 133963186878784> is referred to by <FramesList 133963188017072>._frames
<FramesList 133963188017072> is referred to by <frame 95552919519184>
<frame 95552919519184> is referred to by <generator 133963188390256>

I suspect that, whatever your model object is, it is doing something similar to this.
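
A minimal sketch of one way to avoid this kind of leak, shown on the toy Model above rather than on model_dde (whose code isn't shown): keep __call__ pure and return the value you wanted to cache as an extra output, so it leaves the trace through the return value and comes back as a concrete array. The name aux is illustrative only.

```python
import jax
import jax.numpy as jnp

class Model:
    """Pure variant of the toy Model: nothing is written to self."""

    def __call__(self, x):
        # Instead of `self.cache = x`, hand the value back to the caller as
        # an extra output; vmap then returns it as an ordinary array.
        aux = x
        return x, aux

model = Model()
x = jnp.arange(5)

with jax.check_tracer_leaks():
    out, cached = jax.vmap(model)(x)  # no tracer outlives the trace

# cached is a concrete array, so storing it outside of tracing is safe.
model_cache = cached
```

If the real model needs state across calls, the same idea applies: thread that state in and out as explicit arguments and return values instead of assigning it to attributes during tracing.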

© www.soinside.com 2019 - 2024. All rights reserved.