我正在使用 jax.profiler.start_trace 在 TPU 上分析 JAX 代码。我试图尽可能减少跟踪的持续时间(在单个 v5e 芯片上运行,...),但我仍然有超过 3/4 的跟踪部分被“跟踪缓冲区丢弃”隐藏,因为2GB protobuf 文件大小限制。但跟踪文件实际上约为 250 或 600MB,因此还有余量。我需要捕捉约 300 秒 如何避免这种情况?增加 protobuf 文件大小,...
我尝试使用 jax.profiler.start_server ,但没有收集到任何内容......它似乎适合较短的跟踪周期。
jax.profiler.start_server
本身不会留下痕迹。它允许您使用 Tensorboard UI 来启动跟踪 (https://jax.readthedocs.io/en/latest/profiling.html#manual-capture-via-tensorboard)。这可能是控制捕获秒数的好方法。
奇怪的是,您的跟踪是 < 1GB, yet it says you're hitting the 2GB limit. I can't comment to ask questions that would help debug, so I suggest filing an issue at https://github.com/jax-ml/jax/issues,我们可以在那里为您提供更多帮助。
作为一种解决方法,我建议捕获许多较小的痕迹,而不是一个大型的 300 秒痕迹。