为什么我的基于 JAX 的图像处理循环会随着时间的推移而变慢?如何优化 GPU 内存使用?

问题描述 投票:0回答:1

我正在使用 JAX 处理图像处理任务,并且遇到了循环性能随着时间的推移而显着下降的问题。具体来说,前几次迭代运行得很快,但随后的迭代速度明显减慢。此外,我注意到整个过程中 GPU 内存使用率仍然很高。

这是我的代码的简化版本:

@jax.jit
def raybyray(VoxelSpacing, VoxelNum, DV, P_camera, P_image, offset):
    # Function implementation
    pass

@jax.jit
def image_process(VoxelSpacing, VoxelNum, DV, P_camera, P_image, offset):
    # Function implementation
    pass

for i in range(10):
    %time drr_image = image_process(raybyray(VoxelSpacing, VoxelNum, DV, P_camera, P_image, offset))
CPU times: user 2.15 ms, sys: 0 ns, total: 2.15 ms
Wall time: 1.74 ms
CPU times: user 706 µs, sys: 0 ns, total: 706 µs
Wall time: 718 µs
CPU times: user 933 µs, sys: 0 ns, total: 933 µs
Wall time: 814 µs
CPU times: user 3.88 ms, sys: 0 ns, total: 3.88 ms
Wall time: 462 ms
CPU times: user 2.27 ms, sys: 1.03 ms, total: 3.31 ms
Wall time: 375 ms
CPU times: user 4.8 ms, sys: 66 µs, total: 4.86 ms
Wall time: 376 ms
CPU times: user 4.01 ms, sys: 87 µs, total: 4.1 ms
Wall time: 392 ms
CPU times: user 4.77 ms, sys: 0 ns, total: 4.77 ms
Wall time: 380 ms
CPU times: user 2.56 ms, sys: 0 ns, total: 2.56 ms
Wall time: 377 ms
CPU times: user 5.81 ms, sys: 0 ns, total: 5.81 ms
Wall time: 391 ms

我怀疑可能存在 GPU 内存泄漏或内存管理效率低下导致速度变慢。我尝试过使用 gc.collect(),但问题仍然存在。

问题:

什么可能导致我的循环性能下降? 如何优化 GPU 内存使用以防止这种速度下降? 是否有在 JAX 中管理 GPU 内存以避免此类问题的最佳实践? 任何见解或建议将不胜感激!

谢谢!

python memory gpu jax
1个回答
0
投票

我怀疑您被 JAX 的 异步调度误导了:对于前几次迭代,您不是测量运行时间,而是测量调度时间。然后调度队列填满,最后几次迭代正在测量先前调用的实际运行时间。

要解决此问题,您可以使用

jax.block_until_ready()
:

包装输出
for i in range(10):
    %time drr_image = jax.block_until_ready(
        image_process(raybyray(VoxelSpacing, VoxelNum, DV, P_camera, P_image, offset)))

有关有效 JAX 微基准测试应记住的更多信息,请参阅 JAX 常见问题解答:对 JAX 代码进行基准测试

© www.soinside.com 2019 - 2024. All rights reserved.