我遇到了 Jax 的问题,如果我不解决它,我将重写整个 20000 行的应用程序。
我有一个非机器学习应用程序,它依赖于 pytree 来存储数据,并且 pytree 很深 - 大约 6-7 层数据存储(class1 存储 class2,并且存储 class3 的数组等)
我使用 python 列表来存储 pytree,并希望对它们进行 vmap,但事实证明 jax 无法对列表进行 vmap。
(因此,一种解决方案是将每个数据类重写为结构化数组,并从那里开始工作,可能将所有 6-7 层数据放入一个巨型数组中)
有没有办法避免重写?有没有办法将 pytree 类存储在 vmappable 状态,以便一切都像以前一样工作?
如果有帮助的话,我已经用 flax.struct.dataclass 标记了我的类。
jax.vmap
被设计为与数组结构模式一起使用,听起来你有一个结构数组模式。从您的描述来看,听起来您有一系列嵌套结构,如下所示:
import jax
import jax.numpy as jnp
from flax.struct import dataclass
@dataclass
class Params:
x: jax.Array
y: jax.Array
@dataclass
class AllParams:
p: list[Params]
params_list = [AllParams([Params(4, 2), Params(4, 3)]),
AllParams([Params(3, 5), Params(2, 4)]),
AllParams([Params(3, 2), Params(6, 3)])]
然后你有一个要应用于列表中每个元素的函数;像这样的:
def some_func(params):
a, b = params.p
return a.x * b.y - b.x * a.y
[some_func(params) for params in params_list]
[4, 2, -3]
但是正如您所发现的,如果您尝试使用
vmap
执行此操作,则会收到错误:
jax.vmap(some_func)(params_list)
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
问题在于
vmap
分别对传递给它的列表或 pytree 的每个条目进行操作,而不是对列表的元素进行操作。
为了解决这个问题,您通常可以将数据结构从结构数组转换为数组结构,然后对其应用
vmap
。例如:
params_array = jax.tree.map(lambda *vals: jnp.array(vals), *params_list)
print(params_array)
AllParams(p=[ Params(x=Array([4, 3, 3], dtype=int32), y=Array([2, 5, 2], dtype=int32)), Params(x=Array([4, 2, 6], dtype=int32), y=Array([3, 4, 3], dtype=int32)) ])
现在您的数据采用这种格式,
vmap
应该可以工作:
jax.vmap(some_func)(params_array)
Array([ 4, 2, -3], dtype=int32)
现在,假设列表中的每个数据类都具有相同的结构:如果不是,则
vmap
将不适用,因为根据设计,它必须映射具有相同结构的计算。