我有多个 Nx3 点,并且我从其相应的多元高斯中顺序为每个点生成一个新值,每个点都有 1x3 均值和 3x3 cov。所以,我总共有数组:Nx3 点数组、Nx3 均值数组和 Nx3x3 cov 数组。
我只看到如何使用经典的 for 循环来做到这一点:
import numpy as np
from scipy.stats import multivariate_normal
# Generate example data
N = 5 # Small number for minimal example, can be increased for real use case
points = np.random.rand(N, 3)
means = np.random.rand(N, 3)
covs = np.array([np.eye(3) for _ in range(N)]) # Identity matrices as example covariances
# Initialize an array to store the PDF values
pdf_values = np.zeros(N)
# Loop over each point, mean, and covariance matrix
for i in range(N):
pdf_values[i] = multivariate_normal.pdf(points[i], mean=means[i], cov=covs[i])
print("Points:\n", points)
print("Means:\n", means)
print("Covariances:\n", covs)
print("PDF Values:\n", pdf_values)
有什么办法可以加快速度吗?我尝试将所有内容直接传递给 multivariate_normal.pdf,但也从似乎不受支持的文档中传递(与为 Nx3 点生成值的更简单的情况不同,但具有相同的均值和协方差。
也许有些实现不是来自 scipy?
我可能抱有太大希望,但不知何故,我希望有一种更简单的方法来加快速度,并避免直接使用 Pythonic 循环在大量数据中使用此 for 循环进行迭代。
生成数据:
import numpy as np
from scipy.stats import multivariate_normal, wishart
# Generate example data
N = 5 # Small number for minimal example, can be increased for real use case
p = 3 # dimension
np.random.seed(10)
points = np.random.rand(N, p)
means = np.random.rand(N, p)
covs = wishart.rvs(p, np.eye(p),N, 2) # Identity matrices as example covariances
# Initialize an array to store the PDF values
pdf_values = np.zeros(N)
# Loop over each point, mean, and covariance matrix
for i in range(N):
pdf_values[i] = multivariate_normal.pdf(points[i], mean=means[i], cov=covs[i])
pdf_values
array([0.03356053, 0.03167125, 0.08042358, 0.04351325, 0.1328082 ])
如果您觉得输入方程式很无聊,并且您不介意重复自己,请使用`:
multivariate_normal.pdf(np.linalg.solve(np.linalg.cholesky(covs), points - means), np.zeros(p))
array([0.03369268, 0.05241376, 0.04136256, 0.05917483, 0.02996323])
否则我建议将函数写出来,因为
multivariate_normal
会进行胆汁分解,但我们已经这样做了。
y = np.linalg.solve(np.linalg.cholesky(covs), points - means)
np.exp((-0.5* (y**2 + np.log(2*np.pi))).sum(1))
array([0.03369268, 0.05241376, 0.04136256, 0.05917483, 0.02996323])