任务: 我有两个长度为 N 的矩阵 A、B 列表。对于每对元素 A[i]、B[i] 形状,矩阵乘积是明确定义的,但是对于 $0,\dots, N- 中的每个 i 1$ 形状可以不同。因此,我无法将它们堆叠在数组中。形状是静态的。
我想达到与以下相同的结果:
out = [None] * length(A)
for i, a, b in enumerate(zip(A,B)):
out[i] = a @ b
但是,我想与 jax 并行执行此操作。最好的选择是 vmap,但这是不可能的,因为形状不同。
在这里我将讨论我所知道的解决方案以及为什么它们不令人满意。
编写for循环然后jit它。 这将使编译时间在长度 N 上呈超线性增长。这不好,因为我在运行计算之前知道输入和输出的所有形状,所以我希望编译时间恒定(提供形状列表)。
使用 jax 中的 fori_loop 原语。 在文档中,有以下内容:
fori_loop 的语义由这个 Python 实现给出:
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
但是,我的情况更简单:我不需要关心迭代中的 val。这意味着
fori
是连续的。虽然我的情况是平行的。因此,应该可以做得更好。
用零填充,使用vmap,读取结果。 我不控制形状的分布,因此如果只有一个形状很大,可能会导致内存爆炸。
使用 lax.map 在这里(jax.lax.map 和 jax.vmap 之间的权衡是什么?)我读到了以下内容:
lax.map 解决方案通常会很慢,因为它总是按顺序执行,迭代之间不可能进行融合/并行化。
所以我不知道该怎么办。 谢谢!
我认为你最好的方法将类似于你原来的公式,尽管你可以避免预先分配
out
列表:
out = [a @ b for a, b in zip(A, B)]
由于 JAX 的异步调度,如果您在 GPU 这样的加速器上运行它,操作将尽可能并行执行。
您提出的所有其他解决方案要么由于静态形状限制而无法工作,要么会产生开销,这将使它们在实践中比这种更直接的方法更糟糕。