在 JAX 中进行“移位”矩阵乘法的更快方法

问题描述 投票:0回答:1

我有两个数组 f 和 g,f 是 N × T × J 维,f 是 T × J 维。我正在尝试在 JAX 中计算以下内容(对于所有 0<=t

请注意,如果 t-a<0 I’d like it to default to 0. What would be the fastest approach?

现在,我创建一个所有可能索引的列表,将相关索引中计算的两个数组按元素相乘,然后将它们相加:

import jax.numpy as jnp

all_indices = jnp.array([(θ, t, a) for θ in range(N) for t in range(T) for a in range(J)])
θ_idx, t_idx, a_idx = all_indices[:, 0], all_indices[:, 1], all_indices[:, 2]
tma_idx = jnp.maximum(t_idx - a_idx, 0)

unrolled = f[θ_idx, t_idx, a_idx] * g[tma_idx, a_idx]
s = unrolled.reshape(N, T, J).sum(axis=(0,2))

这似乎不是特别有效或优雅,我希望有更好的解决方案。

python for-loop matrix vectorization jax
1个回答
0
投票

我怀疑实现这一点的最佳方法是首先移动 2D 矩阵,然后通过

einsum
执行完全缩减。例如:

t = jnp.arange(T)[:, None]
a = jnp.arange(J)
g_shifted = g[jnp.maximum(t - a, 0), a]
s = jnp.einsum("ntj,tj->t", f, g_shifted)

与原始解决方案相比,这将

g
中的索引操作数量减少了
T
倍,并且依赖于高效的einsum操作来计算最终结果。

© www.soinside.com 2019 - 2024. All rights reserved.