Jax 实现类似 Torch 的“Scatter”功能

问题描述 投票:0回答:1

出于图形学习的目的,我正在尝试实现一个全局总和批处理函数,该函数将大小为 (n x d) 的批处理图形表示“x”和相应的批次向量 (n x 1) 作为输入。然后我想计算每个批次的所有图形表示的总和。这是一个图形表示:torch's scatter function

这是我目前的尝试:

def global_sum_pool(x, batch):
    graph_reps = []
    i = 0
    n = jnp.max(batch)
    while True:
        ind = jnp.where(batch == i, True, False).reshape(-1, 1)
        ind = jnp.tile(ind, x.shape[1])
        x_ind = jnp.where(ind == True, x, 0.0)
        graph_reps.append(jnp.sum(x_ind, axis=0))
        if i == n:
            break
        i += 1
    return jnp.array(graph_reps)

我在线上遇到以下异常

if i == n

jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function make_step at /venvs/jax_env/lib/python3.11/site-packages/equinox/_jit.py:37 for jit. 

我理解这是因为在编译时,Jax 事先并不知道“batch”数组的最大值,因此无法分配内存。有谁知道解决方法或不同的实现?

python graph deep-learning pytorch jax
1个回答
0
投票

您应该使用 JAX 的内置

for
 运算符,而不是通过 
scatter
循环来实现此操作。最方便的接口是
Array.at
语法。如果我正确理解你的目标,它可能看起来像这样:

import jax.numpy as jnp
import numpy as np

# Generate some data
num_batches = 4
n = 10
d = 3
x = np.random.randn(n, d)
ind = np.random.randint(low=0, high=num_batches, size=(n,))

#Compute the result with jax.lax.scatter
result = jnp.zeros((num_batches, d)).at[ind].add(x)
print(result.shape)
# (4, 3)
© www.soinside.com 2019 - 2024. All rights reserved.