jax 不同形状矩阵对的并行乘法

问题描述 投票:0回答:1

任务: 我有两个长度为 N 的矩阵 A、B 列表。对于每对元素 A[i]、B[i] 形状,矩阵乘积是明确定义的,但是对于 $0,\dots, N- 中的每个 i 1$ 形状可以不同。因此,我无法将它们堆叠在数组中。形状是静态的。

我想达到与以下相同的结果:

out = [None] * length(A)
for i, a, b in enumerate(zip(A,B)):
   out[i] = a @ b

但是,我想与 jax 并行执行此操作。最好的选择是 vmap,但这是不可能的,因为形状不同。

在这里我将讨论我所知道的解决方案以及为什么它们不令人满意。

  1. 编写for循环然后jit它。 这将使编译时间在长度 N 上呈超线性增长。这不好,因为我在运行计算之前知道输入和输出的所有形状,所以我希望编译时间恒定(提供形状列表)。

  2. 使用 jax 中的 fori_loop 原语。 在文档中,有以下内容:

fori_loop 的语义由这个 Python 实现给出:


def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

但是,我的情况更简单:我不需要关心迭代中的 val。这意味着

fori
是连续的。虽然我的情况是平行的。因此,应该可以做得更好。

  1. 用零填充,使用vmap,读取结果。 我不控制形状的分布,因此如果只有一个形状很大,可能会导致内存爆炸。

  2. 使用 lax.map 在这里(jax.lax.map 和 jax.vmap 之间的权衡是什么?)我读到了以下内容:

lax.map 解决方案通常会很慢,因为它总是按顺序执行,迭代之间不可能进行融合/并行化。

所以我不知道该怎么办。 谢谢!

python jax
1个回答
0
投票

我认为你最好的方法将类似于你原来的公式,尽管你可以避免预先分配

out
列表:

out = [a @ b for a, b in zip(A, B)]

由于 JAX 的异步调度,如果您在 GPU 这样的加速器上运行它,操作将尽可能并行执行。

您提出的所有其他解决方案要么由于静态形状限制而无法工作,要么会产生开销,这将使它们在实践中比这种更直接的方法更糟糕。

© www.soinside.com 2019 - 2024. All rights reserved.