在 Pytree 上存储和 jax.vmap()

问题描述 投票:0回答:1

我遇到了 Jax 的问题,如果我不解决它,我将重写整个 20000 行的应用程序。

我有一个非机器学习应用程序,它依赖于 pytree 来存储数据,并且 pytree 很深 - 大约 6-7 层数据存储(class1 存储 class2,并且存储 class3 的数组等)

我使用 python 列表来存储 pytree,并希望对它们进行 vmap,但事实证明 jax 无法对列表进行 vmap。

(因此,一种解决方案是将每个数据类重写为结构化数组,并从那里开始工作,可能将所有 6-7 层数据放入一个巨型数组中)

有没有办法避免重写?有没有办法将 pytree 类存储在 vmappable 状态,以便一切都像以前一样工作?

如果有帮助的话,我已经用 flax.struct.dataclass 标记了我的类。

jax
1个回答
0
投票

jax.vmap
被设计为与数组结构模式一起使用,听起来你有一个结构数组模式。从您的描述来看,听起来您有一系列嵌套结构,如下所示:

import jax
import jax.numpy as jnp
from flax.struct import dataclass

@dataclass
class Params:
  x: jax.Array
  y: jax.Array


@dataclass
class AllParams:
  p: list[Params]


params_list = [AllParams([Params(4, 2), Params(4, 3)]),
               AllParams([Params(3, 5), Params(2, 4)]),
               AllParams([Params(3, 2), Params(6, 3)])]

然后你有一个要应用于列表中每个元素的函数;像这样的:

def some_func(params):
  a, b = params.p
  return a.x * b.y - b.x * a.y

[some_func(params) for params in params_list]
[4, 2, -3]

但是正如您所发现的,如果您尝试使用

vmap
执行此操作,则会收到错误:

jax.vmap(some_func)(params_list)
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

问题在于

vmap
分别对传递给它的列表或 pytree 的每个条目进行操作,而不是对列表的元素进行操作。

为了解决这个问题,您通常可以将数据结构从结构数组转换为数组结构,然后对其应用

vmap
。例如:

params_array = jax.tree.map(lambda *vals: jnp.array(vals), *params_list)
print(params_array)
AllParams(p=[
  Params(x=Array([4, 3, 3], dtype=int32), y=Array([2, 5, 2], dtype=int32)),
  Params(x=Array([4, 2, 6], dtype=int32), y=Array([3, 4, 3], dtype=int32))
])

现在您的数据采用这种格式,

vmap
应该可以工作:

jax.vmap(some_func)(params_array)
Array([ 4,  2, -3], dtype=int32)

现在,假设列表中的每个数据类都具有相同的结构:如果不是,则

vmap
将不适用,因为根据设计,它必须映射具有相同结构的计算。

© www.soinside.com 2019 - 2024. All rights reserved.