我有两个数组 f 和 g,f 是 N × T × J 维,f 是 T × J 维。我正在尝试在 JAX 中计算以下内容(对于所有 0<=t
请注意,如果 t-a<0 I’d like it to default to 0. What would be the fastest approach?
现在,我创建一个所有可能索引的列表,将相关索引中计算的两个数组按元素相乘,然后将它们相加:
import jax.numpy as jnp
all_indices = jnp.array([(θ, t, a) for θ in range(N) for t in range(T) for a in range(J)])
θ_idx, t_idx, a_idx = all_indices[:, 0], all_indices[:, 1], all_indices[:, 2]
tma_idx = jnp.maximum(t_idx - a_idx, 0)
unrolled = f[θ_idx, t_idx, a_idx] * g[tma_idx, a_idx]
s = unrolled.reshape(N, T, J).sum(axis=(0,2))
这似乎不是特别有效或优雅,我希望有更好的解决方案。
我怀疑实现这一点的最佳方法是首先移动 2D 矩阵,然后通过
einsum
执行完全缩减。例如:
t = jnp.arange(T)[:, None]
a = jnp.arange(J)
g_shifted = g[jnp.maximum(t - a, 0), a]
s = jnp.einsum("ntj,tj->t", f, g_shifted)
与原始解决方案相比,这将
g
中的索引操作数量减少了T
倍,并且依赖于高效的einsum操作来计算最终结果。