Colab、Jax 和 GPU:为什么单元执行需要 60 秒,而 %%timeit 说只需要 70 毫秒?

问题描述 投票:0回答:1

作为分形项目的基础,我尝试使用 Jax 库在 Google Colab 上使用 GPU 计算。

我在所有加速器上使用 Mandelbrot 作为模型,但遇到了问题。

当我使用

%%timeit

 命令测量计算我的 GPU 函数所需的时间(与模型笔记本中相同)时,时间完全合理,并且符合预期结果 - 
70 到 80 毫秒

但实际上

跑步%%timeit

需要大约
整整一分钟。 (默认情况下,它连续运行该函数 7 次并报告平均值 - 但即使这样也需要不到一秒的时间。)

同样,当我在单元格中运行该函数并输出结果(6 兆像素图像)时,单元格大约需要 60 秒才能完成 - 执行一个据称只需要 70-80 毫秒的函数。

似乎有些东西正在产生大量的开销,而且似乎也随着计算量的增加而扩展——例如当函数包含 1,000 次迭代计算时,

%%timeit

 表示需要 71 毫秒,而实际上需要 60 秒,但只有 20 次迭代时,
%%timeit
 表示需要 10 毫秒,而实际上需要大约 10 秒。

我粘贴了下面的代码,但是

这里是 Colab 笔记本本身的链接——任何人都可以制作副本,连接到“T4 GPU”实例,然后自己运行它来查看。

import math import numpy as np import matplotlib.pyplot as plt import jax assert len(jax.devices("gpu")) == 1 def run_jax_kernel(c, fractal): z = c for i in range(1000): z = z**2 + c diverged = jax.numpy.absolute(z) > 2 diverging_now = diverged & (fractal == 1000) fractal = jax.numpy.where(diverging_now, i, fractal) return fractal run_jax_gpu_kernel = jax.jit(run_jax_kernel, backend="gpu") def run_jax_gpu(height, width): mx = -0.69291874321833995150613818345974774914923989808007473759199 my = 0.36963080032727980808623018005116209090839988898368679237704 zw = 4 / 1e3 y, x = jax.numpy.ogrid[(my-zw/2):(my+zw/2):height*1j, (mx-zw/2):(mx+zw/2):width*1j] c = x + y*1j fractal = jax.numpy.full(c.shape, 1000, dtype=np.int32) return np.asarray(run_jax_gpu_kernel(c, fractal).block_until_ready())
生成图像大约需要一分钟:

fig, ax = plt.subplots(1, 1, figsize=(15, 10)) ax.imshow(run_jax_gpu(2000, 3000));
大约需要一分钟报告该函数只需要 70-80 毫秒执行:

%%timeit -o run_jax_gpu(2000, 3000)
    
python jupyter-notebook google-colaboratory jax
1个回答
0
投票
首先要意识到的是,

%timeit

会多次执行你的代码,然后返回每次运行的平均次数。它将执行的次数由第一次运行的时间动态确定。

要认识到的第二件事是 JAX 代码是即时 (JIT) 编译的,这意味着在第一次执行任何特定函数时,您将产生一次性编译成本。许多因素都会影响编译成本,但使用大型

for

 循环(例如 1000 次或更多重复)的函数往往编译速度非常慢,因为 JAX 在将操作传递给 XLA 之前展开这些循环。

将这些放在一起,您就会明白为什么您会观察到这样的计时:在

%timeit

 下,您的第一次运行会导致非常长的编译,而后续运行则非常快。结果平均时间被打印出来,而且非常短。

当您运行一次代码来绘制结果时,您主要看到的是编译时间。因为它不会通过多次调用函数来摊销,所以编译时间很长。

解决方案是避免在函数中编写 Python

for

 循环,以避免较长的编译时间:一种可能是使用 
lax.fori_loop
,它允许您编写迭代计算,而不会造成巨大的编译时间损失,尽管与 
for
 循环解决方案相比,它会在 GPU 上产生运行时损失,因为操作是按顺序执行的,而不是由编译器并行执行。

© www.soinside.com 2019 - 2024. All rights reserved.