jax自下而上处理不完美的对象树

问题描述 投票:0回答:1

我有一个不完美、不完整的树,我想自下而上地处理它。 IE。 level-t 输出一些值,这些值将用于 level-(t+1) 的计算,但 level t+1 也包含一些固定值。计算涉及静态形状但不同大小的数组。

我知道如何处理每个级别。另外,对于整个计算,我知道如何预先计算计算期间需要的所有索引。但是,我不明白如何在 jax 中组织从输出级别到下一个级别的写入值。此操作是本地的,因此我不想每次都重新创建整个列表。如果输出可以存储在单个数组而不是列表中,我会使用就地覆盖,所以我不能。

接下来我将更详细地讨论发生的事情。我将仅使用整数求和,以便仅强调计算结构,因此我对具有不同形状张量的一般非交换运算感兴趣。因此,一些简化是行不通的。

只是蟒蛇

def bin_(a, b):
    return a + b

def process_level(do, pairs):
    return [do(*pair) for pair in pairs]

def write_level_(what_write, where_write, seq):
    for k, v in zip(where_write, what_write):
        seq[k] = v

a, b, c, d, e, f = tuple(range(6))
seq = [None, None, e, None, f, None, a, b, c, d]

level_idx = [[6,7,8,9],[2,3,4,5],[0,1]]
level_p_idx = [[3,5],[0,1]]

for l, p in zip(level_idx, level_p_idx):
    pairs = zip(l[::2],l[1::2])
    l_out = process_level(bin_, pairs)
    write_level_(l_out, p, seq)

因此, write_level_ 改变了原地的东西,因此是不可接受的。但是,我不知道该怎么办。可以注意到,该问题与 jax.lax.scan 类似。这是正确的,但不完全是:扫描假设完美的树。任何建议表示赞赏。我想我只是不知道有关 jax 的基本知识如何修改,所以我自己无法弄清楚。

jax
1个回答
0
投票

通过稍微修改(制作

seq
和数组,将就地列表更新更改为就地数组更新,并返回更新后的数组),您可以使此代码与
jax.jit
兼容。例如:

import jax

def bin_(a, b):
    return a + b

def process_level(do, pairs):
    return [do(*pair) for pair in pairs]

def write_level_(what_write, where_write, seq):
    for k, v in zip(where_write, what_write):
        seq = seq.at[k].set(v)
    return seq

a, b, c, d, e, f = tuple(range(6))
seq = jax.numpy.array([-1, -1, e, -1, f, -1, a, b, c, d])

level_idx = [[6,7,8,9],[2,3,4,5],[0,1]]
level_p_idx = [[3,5],[0,1]]

@jax.jit
def f(seq, level_idx, level_p_idx):
    for l, p in zip(level_idx, level_p_idx):
        pairs = zip(l[::2],l[1::2])
        l_out = process_level(bin_, pairs)
        seq = write_level_(l_out, p, seq)
    return seq

seq = f(seq, level_idx, level_p_idx)
print(seq)
# [ 5  9  4 13  5 17  0  1  2  3]

我认为这是一般树结构所能做到的最好的;如果节点的子节点数量和形状相同,您可以将

for
循环替换为
lax.scan
之类的内容,但这取决于您的最终目标,这可能是可能的,也可能是不可能的。

© www.soinside.com 2019 - 2024. All rights reserved.