作为分形项目的基础,我尝试使用 Jax 库在 Google Colab 上使用 GPU 计算。
我在所有加速器上使用 Mandelbrot 作为模型,但遇到了问题。
当我使用%%timeit
命令测量计算我的 GPU 函数所需的时间(与模型笔记本中相同)时,时间完全合理,并且符合预期结果 -70 到 80 毫秒。 但实际上
跑步%%timeit
需要大约整整一分钟。 (默认情况下,它连续运行该函数 7 次并报告平均值 - 但即使这样也需要不到一秒的时间。) 同样,当我在单元格中运行该函数并输出结果(6 兆像素图像)时,单元格大约需要 60 秒才能完成 - 执行一个据称只需要 70-80 毫秒的函数。
似乎有些东西正在产生大量的开销,而且似乎也随着计算量的增加而扩展——例如当函数包含 1,000 次迭代计算时,
%%timeit
表示需要 71 毫秒,而实际上需要 60 秒,但只有 20 次迭代时,
%%timeit
表示需要 10 毫秒,而实际上需要大约 10 秒。我粘贴了下面的代码,但是
这里是 Colab 笔记本本身的链接——任何人都可以制作副本,连接到“T4 GPU”实例,然后自己运行它来查看。
import math
import numpy as np
import matplotlib.pyplot as plt
import jax
assert len(jax.devices("gpu")) == 1
def run_jax_kernel(c, fractal):
z = c
for i in range(1000):
z = z**2 + c
diverged = jax.numpy.absolute(z) > 2
diverging_now = diverged & (fractal == 1000)
fractal = jax.numpy.where(diverging_now, i, fractal)
return fractal
run_jax_gpu_kernel = jax.jit(run_jax_kernel, backend="gpu")
def run_jax_gpu(height, width):
mx = -0.69291874321833995150613818345974774914923989808007473759199
my = 0.36963080032727980808623018005116209090839988898368679237704
zw = 4 / 1e3
y, x = jax.numpy.ogrid[(my-zw/2):(my+zw/2):height*1j, (mx-zw/2):(mx+zw/2):width*1j]
c = x + y*1j
fractal = jax.numpy.full(c.shape, 1000, dtype=np.int32)
return np.asarray(run_jax_gpu_kernel(c, fractal).block_until_ready())
生成图像大约需要一分钟:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))
ax.imshow(run_jax_gpu(2000, 3000));
大约需要一分钟报告该函数只需要 70-80 毫秒执行:
%%timeit -o
run_jax_gpu(2000, 3000)
%timeit
会多次执行你的代码,然后返回每次运行的平均次数。它将执行的次数由第一次运行的时间动态确定。要认识到的第二件事是 JAX 代码是即时 (JIT) 编译的,这意味着在第一次执行任何特定函数时,您将产生一次性编译成本。许多因素都会影响编译成本,但使用大型
for
循环(例如 1000 次或更多重复)的函数往往编译速度非常慢,因为 JAX 在将操作传递给 XLA 之前展开这些循环。将这些放在一起,您就会明白为什么您会观察到这样的计时:在
%timeit
下,您的第一次运行会导致非常长的编译,而后续运行则非常快。结果平均时间被打印出来,而且非常短。当您运行一次代码来绘制结果时,您主要看到的是编译时间。因为它不会通过多次调用函数来摊销,所以编译时间很长。
解决方案是避免在函数中编写 Python
for
循环,以避免较长的编译时间:一种可能是使用
lax.fori_loop
,它允许您编写迭代计算,而不会造成巨大的编译时间损失,尽管与 for
循环解决方案相比,它会在 GPU 上产生运行时损失,因为操作是按顺序执行的,而不是由编译器并行执行。