Faster Sampling with JAX and Numba#
PyMC can compile its models to various execution backends through PyTensor, including:
C
JAX
Numba
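By default, PyMC uses the C backend, which is then called by the Python-based samplers.
However, by compiling to other backends, we can use samplers written in languages other than Python, which call the PyMC model without any Python overhead.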
For the JAX backend, the NumPyro and BlackJAX NUTS samplers are available. To use them, you need to install numpyro and blackjax. Both are available through conda/mamba: mamba install -c conda-forge numpyro blackjax
For the Numba backend, there is the Nutpie sampler, which is written in Rust. To use it, you need to install nutpie: mamba install -c conda-forge nutpie
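Before requesting one of the external samplers, it can help to confirm that its package is importable. The following is a minimal sketch using only the standard library; the helper name available_nuts_samplers and the mapping SAMPLER_PACKAGES are illustrative assumptions, while the sampler names match the nuts_sampler values used later on this page.

```python
from importlib.util import find_spec

# nuts_sampler value -> package that has to be importable.
# The mapping itself is an assumption made for this sketch.
SAMPLER_PACKAGES = {
    "numpyro": "numpyro",
    "blackjax": "blackjax",
    "nutpie": "nutpie",
}


def available_nuts_samplers():
    """Return the nuts_sampler values whose backing package is installed."""
    available = ["pymc"]  # the default Python NUTS sampler needs no extra package
    for sampler, package in SAMPLER_PACKAGES.items():
        # find_spec returns None when a top-level package cannot be found
        if find_spec(package) is not None:
            available.append(sampler)
    return available


print(available_nuts_samplers())
```

This only checks that the packages can be found. It does not verify that their versions are compatible with the installed PyMC.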
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
rng = np.random.default_rng(seed=42)
print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v5.6.0
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
We will use a simple probabilistic PCA model as our example.
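For reference, the generative model can be written out as follows. This is read off the simulation and model code on this page, not taken from a separate derivation; the symbols W, Z and the index names d, k, n are notation chosen here.

```latex
\begin{aligned}
W_{dk} &\sim \mathcal{N}(0,\, 2^2), & W &\in \mathbb{R}^{D \times K},\\
z_{kn} &\sim \mathcal{N}(0,\, 1), & Z &\in \mathbb{R}^{K \times N},\\
x_{dn} \mid W, Z &\sim \mathcal{N}\big((WZ)_{dn},\, \sigma^2\big), & \sigma &= 1.
\end{aligned}
```

Integrating out the latent variables gives each data point the marginal distribution $x_n \sim \mathcal{N}(0,\, W W^\top + \sigma^2 I)$.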
def build_toy_dataset(N, D, K, sigma=1):
    x_train = np.zeros((D, N))
    w = rng.normal(
        0.0,
        2.0,
        size=(D, K),
    )
    z = rng.normal(0.0, 1.0, size=(K, N))
    mean = np.dot(w, z)
    for d in range(D):
        for n in range(N):
            x_train[d, n] = rng.normal(mean[d, n], sigma)
    print("True principal axes:")
    print(w)
    return x_train
N = 5000 # number of data points
D = 2 # data dimensionality
K = 1 # latent dimensionality
data = build_toy_dataset(N, D, K)
True principal axes:
[[ 0.60943416]
[-2.07996821]]
plt.scatter(data[0, :], data[1, :], color="blue", alpha=0.1)
plt.axis([-10, 10, -10, 10])
plt.title("Simulated data set")
Text(0.5, 1.0, 'Simulated data set')
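As a quick numerical companion to the scatter plot, the sketch below simulates data from the same kind of model and compares the empirical covariance with the covariance implied by probabilistic PCA, W Wᵀ + σ²I. It reuses the principal axes printed above; the seed, the vectorized noise draw, and the variable names (rng_check, w_true, x_sim) are assumptions for illustration, so the simulated points are not the ones in the notebook.

```python
import numpy as np

rng_check = np.random.default_rng(0)  # hypothetical seed, not the notebook's rng

N, D, K, sigma = 5000, 2, 1, 1.0
# Principal axes as printed in the output above
w_true = np.array([[0.60943416], [-2.07996821]])

# Vectorized equivalent of the double loop: x = W z + noise
z_sim = rng_check.normal(0.0, 1.0, size=(K, N))
x_sim = w_true @ z_sim + rng_check.normal(0.0, sigma, size=(D, N))

implied_cov = w_true @ w_true.T + sigma**2 * np.eye(D)
empirical_cov = np.cov(x_sim)  # rows are variables, columns are observations

print("Implied covariance:\n", implied_cov)
print("Empirical covariance:\n", empirical_cov)
```

With 5000 points the two matrices should agree closely, and the negative off-diagonal entry corresponds to the downward-sloping cloud in the plot.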

with pm.Model() as PPCA:
    w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered())
    z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
    x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
Sampling using the Python NUTS sampler#
%%time
with PPCA:
    idata_pymc = pm.sample()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, z]
100.00% [8000/8000 00:28<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 29 seconds.
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/arviz/utils.py:184: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
numba_fn = numba.jit(**self.kwargs)(self.function)
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
CPU times: user 19.7 s, sys: 971 ms, total: 20.7 s
Wall time: 47.6 s
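The warning in the output above reports an R-hat larger than 1.01 for some parameters. As a rough illustration of what that statistic measures, here is a minimal NumPy sketch of the classic Gelman-Rubin R-hat, which compares between-chain and within-chain variance. It is not ArviZ's implementation (ArviZ defaults to the rank-normalized split version described in the linked paper), and the helper name basic_rhat and the synthetic chains are assumptions made for illustration.

```python
import numpy as np


def basic_rhat(chains):
    """Classic Gelman-Rubin R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    # W: average within-chain variance
    within = chains.var(axis=1, ddof=1).mean()
    # B: between-chain variance, scaled by the number of draws
    between = n_draws * chains.mean(axis=1).var(ddof=1)
    # Pooled estimate of the posterior variance
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(var_hat / within))


rng_demo = np.random.default_rng(1)  # hypothetical seed
mixed = rng_demo.normal(size=(4, 1000))  # four chains exploring the same distribution
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])  # two chains offset by 3

print("well mixed:", basic_rhat(mixed))
print("disagreeing chains:", basic_rhat(stuck))
```

Chains that agree give a value close to 1, while chains centered on different values push it well above the 1.01 threshold mentioned in the warning.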
Sampling using the NumPyro JAX NUTS sampler#
%%time
with PPCA:
    idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Compiling...
Compilation time = 0:00:00.619901
Sampling...
Sampling time = 0:00:11.469112
Transforming variables...
Transformation time = 0:00:00.118111
CPU times: user 40.5 s, sys: 6.66 s, total: 47.2 s
Wall time: 12.9 s
Sampling using the BlackJAX NUTS sampler#
%%time
with PPCA:
    idata_blackjax = pm.sample(nuts_sampler="blackjax")
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
Compiling...
Compilation time = 0:00:00.607693
Sampling...
Sampling time = 0:00:02.132882
Transforming variables...
Transformation time = 0:00:08.410508
CPU times: user 35.4 s, sys: 6.73 s, total: 42.1 s
Wall time: 11.6 s
Sampling using the Nutpie Rust NUTS sampler#
%%time
with PPCA:
    idata_nutpie = pm.sample(nuts_sampler="nutpie")
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/util.py:501: FutureWarning: The tag attribute observations is deprecated. Use model.rvs_to_values[rv] instead
warnings.warn(
100.00% [8000/8000 00:09<00:00 Chains in warmup: 0, Divergences: 0]
CPU times: user 37.6 s, sys: 3.34 s, total: 41 s
Wall time: 16.1 s
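To put the four runs side by side, the sketch below collects the wall times reported in the outputs above and expresses each as a speedup over the default sampler. The numbers are copied from this page; they come from a single run on one machine and include compilation, so they should be read as indicative only. Sampling efficiency (for example effective sample size per second) is not compared here.

```python
# Wall times in seconds, copied from the %%time outputs on this page
wall_times = {
    "pymc": 47.6,
    "numpyro": 12.9,
    "blackjax": 11.6,
    "nutpie": 16.1,
}

baseline = wall_times["pymc"]
speedups = {name: baseline / seconds for name, seconds in wall_times.items()}

# Print the samplers from fastest to slowest
for name, speedup in sorted(speedups.items(), key=lambda item: item[1], reverse=True):
    print(f"{name:>9}: {wall_times[name]:5.1f} s  ({speedup:.1f}x vs. default)")
```

On these numbers the external samplers finish roughly three to four times faster than the default Python NUTS sampler.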
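License notice#
All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use, provided the copyright and license notices are preserved.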
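Citing PyMC examples#
To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.
Important
Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.
Also remember to cite the relevant libraries used by your code.
Here is a citation template in bibtex: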
@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}
which, once rendered, could look like: