我可以使用 GPU 通过 m.sampling_jax.sample_numpyro_nuts 加速 PyMC 模型采样,并且它在我的 Linux 环境中运行良好。然而,当对后验预测进行采样以扩展到我的 idata 时,我无法使用我的 cpu,这使其成为我的模型的瓶颈。
# Sampling
if gpu_available:
idata = pm.sampling_jax.sample_numpyro_nuts(draws=draws_def,
tune=tune_def,
target_accept=targ_acc_def,
chain_method='vectorized',
idata_kwargs={"log_likelihood": True})
else:
idata = pm.sample(draws=draws_def, tune=tune_def, target_accept=targ_acc_def, idata_kwargs={"log_likelihood": True})
idata.extend(pm.sample_posterior_predictive(idata, var_names=["y_obs"]))
您可以将
compile_kwargs=dict(mode="JAX")
传递给sample_posterior_predictive