在 GPU 上使用 JAX 进行批量矩阵乘法时,较大的矩阵反而更快

问题描述 投票:0回答:1

我尝试在 GPU 上使用 JAX 执行批量矩阵乘法,并注意到:形状 (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) 的乘法反而比 (1000, 1000, 3, 25) @ (1000, 1000, 25, 1) 更快,f64 时快约 3 倍,f32 时快约 5 倍。在 CPU 上,JAX 和 NumPy 都没有表现出这种行为;在 GPU 上,CuPy 也没有。该如何解释这种差异?我在 NVIDIA RTX A5000 上使用 JAX 0.4.32 运行此程序(在 Tesla T4 上也得到了类似的结果)。复现代码如下:

# Setup shared by the snippets in this post: NumPy/CuPy/JAX, the cupyx
# benchmark helper, and matplotlib for the plots.
import numpy as np
import cupy as cp
from cupyx.profiler import benchmark
from jax import config
# Enable 64-bit (float64) support in JAX; the benchmarks use float64 arrays.
# Set here, before jax.numpy is imported and before any JAX arrays are created.
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# NumPy random generator used to build every host-side input array.
rng = np.random.default_rng()

# X-axis values for the sweep plots: the inner (contracted) dimension sizes
# 5, 10, ..., 50, matching range(5, 55, 5) in the timing loops.
x = np.arange(5, 55, 5)

GPU 计时:

# CuPy baseline: time the batched matmul (1000, 1000, 3, i) @ (1000, 1000, i, 1)
# in float64 for inner sizes i = 5, 10, ..., 50.
dtype = cp.float64
timings_cp = []
for i in range(5, 55, 5):
    # Inputs are generated on the host with NumPy, then copied into CuPy arrays.
    a = cp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
    b = cp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
    # 10 warm-up calls, then 10 timed calls; the returned result's gpu_times
    # are averaged in the plot.
    timings_cp.append(benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10))

# JAX on the first GPU device: same shapes and inner-size sweep as the CuPy loop.
dtype = jnp.float64
timings_jax_gpu = []
with jax.default_device(jax.devices('gpu')[0]):
    for i in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
        # A fresh jitted matmul per size. Compilation for the new shape happens on
        # the first call, which falls in the warm-up runs, not the timed ones.
        func = jax.jit(lambda a, b: a@b)
        # block_until_ready() makes each call wait for the JAX result, because JAX
        # dispatches work asynchronously.
        # NOTE(review): cupyx's benchmark records gpu_times with CUDA events on
        # CuPy's stream, while JAX launches kernels on its own streams. Confirm
        # that gpu_times really reflects JAX kernel time (cpu_times may be safer).
        timings_jax_gpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))

# Plot the mean GPU time per inner size for CuPy vs JAX.
plt.figure()
plt.plot(x, [i.gpu_times.mean() for i in timings_cp], label="CuPy")
plt.plot(x, [i.gpu_times.mean() for i in timings_jax_gpu], label="JAX GPU")
plt.legend()

[图:CuPy 与 JAX 在 GPU 上的平均耗时随内维大小 i 的变化(原图缺失)]

这些特定形状的计时:

# Time the two shapes of interest (inner size 25 vs 35) on the first GPU,
# with 1000 timed runs each, and print the mean GPU time.
dtype = jnp.float64
with jax.default_device(jax.devices('gpu')[0]):
    a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
    func = jax.jit(lambda a, b: a@b)
    print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())

    # The same jitted function is reused. The new input shape is compiled on
    # the first (warm-up) call.
    a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
    print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())

输出:

f64:
0.01453789699935913
0.004859122595310211

f32:

0.005860503035545349
0.001209742688536644

CPU 计时:

# NumPy on the CPU: same inner-size sweep. The inputs are the float64 arrays
# returned by rng.random (its default dtype).
timings_np = []
for i in range(5, 55, 5):
    a = rng.random((1000, 1000, 3, i))
    b = rng.random((1000, 1000, i, 1))
    timings_np.append(benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10))

# JAX on the first CPU device: same inner-size sweep.
# No explicit dtype is passed. The host arrays are float64 and jax_enable_x64
# is set in the setup, so these are presumably float64 as well.
timings_jax_cpu = []
with jax.default_device(jax.devices('cpu')[0]):
    for i in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, i)))
        b = jnp.array(rng.random((1000, 1000, i, 1)))
        # New jitted matmul per size; compilation lands in the warm-up calls.
        func = jax.jit(lambda a, b: a@b)
        # block_until_ready() so each call waits for JAX's asynchronous result.
        timings_jax_cpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))

# Plot the mean CPU-side time per inner size for NumPy vs JAX on the CPU.
plt.figure()
plt.plot(x, [i.cpu_times.mean() for i in timings_np], label="NumPy")
plt.plot(x, [i.cpu_times.mean() for i in timings_jax_cpu], label="JAX CPU")
plt.legend()

[图:NumPy 与 JAX 在 CPU 上的平均耗时随内维大小 i 的变化(原图缺失)]

python numpy cuda jax cupy
1个回答
0
投票

差异似乎来自编译器针对较小尺寸发出 kLoop 融合,而针对较大尺寸发出 kInput 融合(见下面 HLO 中的 kind=kLoop 与 kind=kInput)。关于这两种融合类型的作用,可以参阅这段源码注释:https://github.com/openxla/xla/blob/e6b6e61b29cc439350a6ad2f9d39535cb06011e5/xla/hlo/ir/hlo_instruction.h#L639-L656

编译器很可能使用某种启发式规则在这两种融合之间进行选择,而对于您这个具体问题,该规则在临界尺寸附近的选择并非最优。打印该操作编译后的 HLO 即可看到这一点:

# Print the compiled (optimized) HLO for inner size 25. In the dump that
# follows, the matmul becomes a single reduce fusion with kind=kLoop.
# `dtype` is the one defined in the earlier benchmark snippets.
a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
print(jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f64[1000,1000,3,25]{3,2,1,0}, f64[1000,1000,25,1]{3,2,1,0})->f64[1000,1000,3,1]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="a02cbfe0fda9d44e2bd23462363b6cc0"}

%scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
  %scalar_rhs = f64[] parameter(1)
  %scalar_lhs = f64[] parameter(0)
  ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
}

%fused_reduce (param_0.7: f64[1000,1000,3,25], param_1.6: f64[1000,1000,25,1]) -> f64[1000,1000,3] {
  %param_0.7 = f64[1000,1000,3,25]{3,2,1,0} parameter(0)
  %param_1.6 = f64[1000,1000,25,1]{3,2,1,0} parameter(1)
  %bitcast.28.5 = f64[1000,1000,25]{2,1,0} bitcast(f64[1000,1000,25,1]{3,2,1,0} %param_1.6)
  %broadcast.2.5 = f64[1000,1000,3,25]{3,2,1,0} broadcast(f64[1000,1000,25]{2,1,0} %bitcast.28.5), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
  %multiply.2.3 = f64[1000,1000,3,25]{3,2,1,0} multiply(f64[1000,1000,3,25]{3,2,1,0} %param_0.7, f64[1000,1000,3,25]{3,2,1,0} %broadcast.2.5)
  %constant_4 = f64[] constant(0)
  ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,25]{3,2,1,0} %multiply.2.3, f64[] %constant_4), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
}

ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,25], Arg_1.2.0: f64[1000,1000,25,1]) -> f64[1000,1000,3,1] {
  %Arg_1.2.0 = f64[1000,1000,25,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
  %Arg_0.1.0 = f64[1000,1000,3,25]{3,2,1,0} parameter(0), metadata={op_name="a"}
  %loop_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,25]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,25,1]{3,2,1,0} %Arg_1.2.0), kind=kLoop, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
  ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %loop_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
}
# Same inspection for inner size 35. In the dump that follows, the same
# reduce fusion is emitted with kind=kInput instead of kind=kLoop.
a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
print(jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text())
%scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
  %scalar_rhs = f64[] parameter(1)
  %scalar_lhs = f64[] parameter(0)
  ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
}

%fused_reduce (param_0.5: f64[1000,1000,3,35], param_1.2: f64[1000,1000,35,1]) -> f64[1000,1000,3] {
  %param_0.5 = f64[1000,1000,3,35]{3,2,1,0} parameter(0)
  %param_1.2 = f64[1000,1000,35,1]{3,2,1,0} parameter(1)
  %bitcast.28.3 = f64[1000,1000,35]{2,1,0} bitcast(f64[1000,1000,35,1]{3,2,1,0} %param_1.2)
  %broadcast.2.3 = f64[1000,1000,3,35]{3,2,1,0} broadcast(f64[1000,1000,35]{2,1,0} %bitcast.28.3), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
  %multiply.2.1 = f64[1000,1000,3,35]{3,2,1,0} multiply(f64[1000,1000,3,35]{3,2,1,0} %param_0.5, f64[1000,1000,3,35]{3,2,1,0} %broadcast.2.3)
  %constant_3 = f64[] constant(0)
  ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,35]{3,2,1,0} %multiply.2.1, f64[] %constant_3), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
}

ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,35], Arg_1.2.0: f64[1000,1000,35,1]) -> f64[1000,1000,3,1] {
  %Arg_1.2.0 = f64[1000,1000,35,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
  %Arg_0.1.0 = f64[1000,1000,3,35]{3,2,1,0} parameter(0), metadata={op_name="a"}
  %input_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,35]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,35,1]{3,2,1,0} %Arg_1.2.0), kind=kInput, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
  ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %input_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
}

下面的脚本可以观察编译器针对不同尺寸所做的选择:

# For inner sizes 10, 15, ..., 50, compile the matmul and report whether the
# optimized HLO contains a kLoop fusion. In the reported run, sizes 10-30 did
# and sizes 35-50 did not.
# The f-string uses self-documenting `=` expressions (Python 3.8+), so the
# names `size` and `hlo_text` appear verbatim in the printed output.
for size in range(10, 55, 5):
  a = jnp.array(rng.random((1000, 1000, 3, size)), dtype=dtype)
  b = jnp.array(rng.random((1000, 1000, size, 1)), dtype=dtype)
  hlo_text = jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text()
  print(f"{size=} {'kLoop' in hlo_text=}")
size=10 'kLoop' in hlo_text=True
size=15 'kLoop' in hlo_text=True
size=20 'kLoop' in hlo_text=True
size=25 'kLoop' in hlo_text=True
size=30 'kLoop' in hlo_text=True
size=35 'kLoop' in hlo_text=False
size=40 'kLoop' in hlo_text=False
size=45 'kLoop' in hlo_text=False
size=50 'kLoop' in hlo_text=False

除了在 https://github.com/openxla/xla 上报告此问题之外,我没有其他建议;编译器在 kLoop 与 kInput 之间进行选择的启发式规则可能需要一些额外的逻辑。

© www.soinside.com 2019 - 2024. All rights reserved.