是否可以通过均值对 scipy multivariate_normal 进行向量化?

问题描述 投票:0回答:1

我有多元正态分布。我在马尔可夫链采样器中使用它,它利用矢量化选项。这意味着当它可以同时请求多个不同均值向量的 PDF 对数时,它是最佳的。下面是一个最小的例子。第一段代码请求单个数据点和单个均值向量,它工作得很好。第二段代码请求多个均值向量。对于第二段代码,我收到错误

ValueError: Array 'mean' must be a vector of length 600.

有没有内置的方法可以解决这个问题,或者我必须使用 for 循环吗?

import numpy as np
from scipy.stats import multivariate_normal

np.random.seed(42)
cov = np.diag(np.ones(6))


mu = np.random.normal(0, 1, 6)
x = np.random.normal(0, 1, 6)
print(multivariate_normal.logpdf(x, mean=mu, cov=cov))

mu = np.random.normal(0, 1, (100, 6))
x = np.random.normal(0, 1, (100, 6))
print(multivariate_normal.logpdf(x, mean=mu, cov=cov))

编辑:我刚刚找到了单个数据点的解决方案。由于正态分布的均值和数据可以互换,因此以下将产生正确的结果

mu = np.random.normal(0, 1, (100, 6))
x = np.random.normal(0, 1, 6)
print(multivariate_normal.logpdf(mu, mean=x, cov=cov))
python scipy
1个回答
0
投票

scipy.stats
不提供此功能,但您也许可以使用支持矢量化的 JAX 实现。 JAX 是鸭子类型的,因此您可以将其用作需要 python 数组的代码的输入,但有一些限制(例如,JAX 没有就地修改)。

这是一个最小的例子:

import jax
import jax.random as jr
from jax.scipy.stats import multivariate_normal

key = jr.key(0)
samples = jr.multivariate_normal(key, jnp.zeros(3), jnp.eye(3), shape=(1000,))

get_logpdfs = jax.vmap(multivariate_normal.logpdf, in_axes=(0, None, None))
logpdfs = get_logpdfs(samples, jnp.zeros(3), jnp.eye(3))

print(samples.shape, logpdfs.shape)
© www.soinside.com 2019 - 2024. All rights reserved.