我想使用贝叶斯线性回归对两个可观察量
O_2
和O_3
进行建模:O_2 ~ m_2 * O_1 + q_2
和O_3 ~ m_3 * O_2 + q_3
。我想通过O_2
将它们结合起来。我正在 Python 中使用 PyMC
框架。我的代码如下所示:
with pm.Model() as model:
# Note: only O_1 is mutable
O_1_data = pm.Data('O_1', data['O_1'], mutable=True)
O_2_data = pm.Data('O_2', data['O_2'], mutable=False)
O_3_data = pm.Data('O_3', data['O_3'], mutable=False)
# Weak priors for regression coefficients
q_2 = pm.Normal('q_2', mu=0, sigma=20)
m_2 = pm.Normal('m_2', mu=0, sigma=20)
sigma_2 = pm.HalfNormal('sigma_2', sigma=std_O2)
q_3 = pm.Normal('q_3', mu=0, sigma=20)
m_3 = pm.Normal('m_3', mu=0, sigma=20)
sigma_3 = pm.HalfNormal('sigma_3', sigma=std_O3)
# Likelihood for O_2
mu_O2 = q_2 + m_2 * O_1_data
O_2_obs = pm.Normal('O_2_obs', mu=mu_O2, sigma=sigma_2, observed=O_2_data)
# Likelihood for O_3 using O_2 predictions
mu_O3 = q_3 + m_3 * O_2_obs
O_3_obs = pm.Normal('O_3_obs', mu=mu_O3, sigma=sigma_3, observed=O_3_data)
# Run MCMC simulation
trace = pm.sample(2000, tune=1000, return_inferencedata=True, random_seed=1503)
现在我想对
O_3
进行后验预测样本采样,给定新的 O_1
值(在模型拟合期间看不到)。按照this链接中的建议,这是我尝试过的:
# Example new O1 test data point:
new_O_1_value = np.array([2.3])
# Generate posterior predictive samples for new O1 values
with model:
# Update with new O1 values
pm.set_data({'O_1': new_O_1_value})
# Generate posterior predictive samples for O3
posterior_predictive = pm.sample_posterior_predictive(trace, predictions=True, random_seed=1503)
# Extract O3 posterior predictive samples
O3_post_pred = posterior_predictive.predictions['O_3_obs']
与我的期望相反,
O3_post_pred
具有形状(# chains, # Monte Carlo samples, # rows in training data)
(即在我的情况下为(4, 2000, 95)
,但我相信它应该是(# chains, # Monte Carlo samples, # rows in new data)
(即在我的情况下(4, 2000, 1)
,因为new_O_1_value只有一个值。)
问题:我做错了什么? 任何帮助表示赞赏。
问题是,在定义
O_2
和 O_3
两个模型的可能性时,我没有传递 shape
参数,如果没有明确设置,它默认为用于拟合的数据的形状模型,即使在传递新输入进行预测时也是如此。正确的代码如下:
O_2_obs = pm.Normal('O_2_obs', mu=mu_O2, sigma=sigma_2, observed=O_2_data, shape=O_1.shape)
...
O_3_obs = pm.Normal('O_3_obs', mu=mu_O3, sigma=sigma_3, observed=O_3_data, shape=O_2_obs.shape)
其余都一样。