I was trying to perform batched matrix multiplication with JAX on the GPU and noticed that multiplying shapes (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) is roughly 3x faster than (1000, 1000, 3, 25) @ (1000, 1000, 25, 1) in f64, and roughly 5x faster in f32. How can this difference be explained, given that neither JAX nor NumPy shows this behaviour on the CPU, and CuPy does not show it on the GPU? I am running this with JAX 0.4.32 on an NVIDIA RTX A5000 (and get similar results on a Tesla T4). Code to reproduce:
import numpy as np
import cupy as cp
from cupyx.profiler import benchmark
from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
rng = np.random.default_rng()
x = np.arange(5, 55, 5)
GPU timings:
dtype = cp.float64
timings_cp = []
for i in range(5, 55, 5):
    a = cp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
    b = cp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
    timings_cp.append(benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10))
dtype = jnp.float64
timings_jax_gpu = []
with jax.default_device(jax.devices('gpu')[0]):
    for i in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
        func = jax.jit(lambda a, b: a@b)
        timings_jax_gpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))
plt.figure()
plt.plot(x, [i.gpu_times.mean() for i in timings_cp], label="CuPy")
plt.plot(x, [i.gpu_times.mean() for i in timings_jax_gpu], label="JAX GPU")
plt.legend()
Timings for these specific shapes:
dtype = jnp.float64
with jax.default_device(jax.devices('gpu')[0]):
    a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
    func = jax.jit(lambda a, b: a@b)
    print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())
    a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
    print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())
which gives:
f64:
0.01453789699935913
0.004859122595310211
f32:
0.005860503035545349
0.001209742688536644
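These numbers correspond to ratios of roughly 0.01454 / 0.00486 ≈ 3.0 for f64 and 0.00586 / 0.00121 ≈ 4.8 for f32, so the (3, 35) × (35, 1) case is about 3x and 5x faster respectively, even though it performs 35/25 = 1.4x as many multiply-adds.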
CPU timings:
timings_np = []
for i in range(5, 55, 5):
    a = rng.random((1000, 1000, 3, i))
    b = rng.random((1000, 1000, i, 1))
    timings_np.append(benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10))
timings_jax_cpu = []
with jax.default_device(jax.devices('cpu')[0]):
    for i in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, i)))
        b = jnp.array(rng.random((1000, 1000, i, 1)))
        func = jax.jit(lambda a, b: a@b)
        timings_jax_cpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))
plt.figure()
plt.plot(x, [i.cpu_times.mean() for i in timings_np], label="NumPy")
plt.plot(x, [i.cpu_times.mean() for i in timings_jax_cpu], label="JAX CPU")
plt.legend()
The difference appears to come from the compiler emitting a kLoop fusion for the smaller sizes and a kInput fusion for the larger sizes. You can read about these fusion kinds in this source comment: https://github.com/openxla/xla/blob/e6b6e61b29cc439350a6ad2f9d39535cb06011e5/xla/hlo/ir/hlo_instruction.h#L639-L656
The compiler presumably uses some heuristic to choose between the two, and for your particular problem that heuristic appears to be suboptimal near the boundary. You can see this by printing the compiled HLO for the operation:
a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
print(jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f64[1000,1000,3,25]{3,2,1,0}, f64[1000,1000,25,1]{3,2,1,0})->f64[1000,1000,3,1]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="a02cbfe0fda9d44e2bd23462363b6cc0"}
%scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
%scalar_rhs = f64[] parameter(1)
%scalar_lhs = f64[] parameter(0)
ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
}
%fused_reduce (param_0.7: f64[1000,1000,3,25], param_1.6: f64[1000,1000,25,1]) -> f64[1000,1000,3] {
%param_0.7 = f64[1000,1000,3,25]{3,2,1,0} parameter(0)
%param_1.6 = f64[1000,1000,25,1]{3,2,1,0} parameter(1)
%bitcast.28.5 = f64[1000,1000,25]{2,1,0} bitcast(f64[1000,1000,25,1]{3,2,1,0} %param_1.6)
%broadcast.2.5 = f64[1000,1000,3,25]{3,2,1,0} broadcast(f64[1000,1000,25]{2,1,0} %bitcast.28.5), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
%multiply.2.3 = f64[1000,1000,3,25]{3,2,1,0} multiply(f64[1000,1000,3,25]{3,2,1,0} %param_0.7, f64[1000,1000,3,25]{3,2,1,0} %broadcast.2.5)
%constant_4 = f64[] constant(0)
ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,25]{3,2,1,0} %multiply.2.3, f64[] %constant_4), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
}
ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,25], Arg_1.2.0: f64[1000,1000,25,1]) -> f64[1000,1000,3,1] {
%Arg_1.2.0 = f64[1000,1000,25,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
%Arg_0.1.0 = f64[1000,1000,3,25]{3,2,1,0} parameter(0), metadata={op_name="a"}
%loop_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,25]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,25,1]{3,2,1,0} %Arg_1.2.0), kind=kLoop, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %loop_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
}
a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
print(jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text())
%scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
%scalar_rhs = f64[] parameter(1)
%scalar_lhs = f64[] parameter(0)
ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
}
%fused_reduce (param_0.5: f64[1000,1000,3,35], param_1.2: f64[1000,1000,35,1]) -> f64[1000,1000,3] {
%param_0.5 = f64[1000,1000,3,35]{3,2,1,0} parameter(0)
%param_1.2 = f64[1000,1000,35,1]{3,2,1,0} parameter(1)
%bitcast.28.3 = f64[1000,1000,35]{2,1,0} bitcast(f64[1000,1000,35,1]{3,2,1,0} %param_1.2)
%broadcast.2.3 = f64[1000,1000,3,35]{3,2,1,0} broadcast(f64[1000,1000,35]{2,1,0} %bitcast.28.3), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
%multiply.2.1 = f64[1000,1000,3,35]{3,2,1,0} multiply(f64[1000,1000,3,35]{3,2,1,0} %param_0.5, f64[1000,1000,3,35]{3,2,1,0} %broadcast.2.3)
%constant_3 = f64[] constant(0)
ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,35]{3,2,1,0} %multiply.2.1, f64[] %constant_3), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
}
ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,35], Arg_1.2.0: f64[1000,1000,35,1]) -> f64[1000,1000,3,1] {
%Arg_1.2.0 = f64[1000,1000,35,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
%Arg_0.1.0 = f64[1000,1000,3,35]{3,2,1,0} parameter(0), metadata={op_name="a"}
%input_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,35]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,35,1]{3,2,1,0} %Arg_1.2.0), kind=kInput, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %input_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
}
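Note that in both cases XLA has rewritten the dot_general as a broadcast, an elementwise multiply, and a reduce over the contracting dimension rather than a batched GEMM call; the two dumps are otherwise identical and differ only in kind=kLoop versus kind=kInput on the fusion. Purely for illustration, a jnp sketch of that same multiply-and-reduce formulation (not necessarily lowered identically by XLA) would be:
def matmul_as_reduce(a, b):
    # b: (..., k, 1) -> (..., 1, k); broadcast against a: (..., 3, k); sum over k
    return jnp.sum(a * jnp.swapaxes(b, -1, -2), axis=-1, keepdims=True)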
Here is a script to observe the compiler's decision as a function of size:
for size in range(10, 55, 5):
    a = jnp.array(rng.random((1000, 1000, 3, size)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, size, 1)), dtype=dtype)
    hlo_text = jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text()
    print(f"{size=} {'kLoop' in hlo_text=}")
size=10 'kLoop' in hlo_text=True
size=15 'kLoop' in hlo_text=True
size=20 'kLoop' in hlo_text=True
size=25 'kLoop' in hlo_text=True
size=30 'kLoop' in hlo_text=True
size=35 'kLoop' in hlo_text=False
size=40 'kLoop' in hlo_text=False
size=45 'kLoop' in hlo_text=False
size=50 'kLoop' in hlo_text=False
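The switch from kLoop to kInput happens between size 30 and 35, consistent with the timing jump between size 25 and size 35 measured above.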
I don't have any suggestion beyond reporting this at https://github.com/openxla/xla; it may be that the compiler heuristic for choosing between emitting kLoop versus kInput fusions needs some additional logic.
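Purely as an illustration of the boundary effect (untested, and not a recommendation), one could in principle nudge the compiler past the observed threshold by zero-padding the contracting dimension, which leaves the product unchanged:
def padded_matmul(a, b, target=35):
    # Untested sketch: pad the contracting dimension with zeros up to `target`
    # (the size at which the compiler switched to the faster kInput fusion above).
    # The extra zeros contribute nothing to the sum, so the result is identical.
    k = a.shape[-1]
    if k >= target:
        return a @ b
    pad = target - k
    a_p = jnp.pad(a, [(0, 0)] * (a.ndim - 1) + [(0, pad)])
    b_p = jnp.pad(b, [(0, 0)] * (b.ndim - 2) + [(0, pad), (0, 0)])
    return a_p @ b_p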