PyMC 使用 JAX 通过 GPU 计算sample_posterior_predictive

问题描述 投票:0回答:1

我可以使用 GPU 通过 m.sampling_jax.sample_numpyro_nuts 加速 PyMC 模型采样,并且它在我的 Linux 环境中运行良好。然而,当对后验预测进行采样以扩展到我的 idata 时,我无法使用我的 cpu,这使其成为我的模型的瓶颈。

# Sampling
if gpu_available:
    idata = pm.sampling_jax.sample_numpyro_nuts(draws=draws_def, 
                                                tune=tune_def, 
                                                target_accept=targ_acc_def, 
                                                chain_method='vectorized', 
                                                idata_kwargs={"log_likelihood": True})

else:
    idata = pm.sample(draws=draws_def, tune=tune_def, target_accept=targ_acc_def, idata_kwargs={"log_likelihood": True})


idata.extend(pm.sample_posterior_predictive(idata, var_names=["y_obs"]))
bayesian jax pymc
1个回答
0
投票

您可以将

compile_kwargs=dict(mode="JAX")
传递给
sample_posterior_predictive

© www.soinside.com 2019 - 2024. All rights reserved.