Pathfinder 变分推断#

Pathfinder [Zhang et al., 2021] 是一种变分推断算法,可以从贝叶斯模型的后验分布中生成样本。它与广泛使用的 ADVI 算法相比更具优势。在大型问题上,它的扩展性应优于大多数 MCMC 算法,包括动态 HMC(即 NUTS),但代价是对后验分布的估计存在更大的偏差。有关该算法的详细信息,请参阅 arxiv 预印本

PyMC 的 Pathfinder 实现现在已使用 PyTensor 本地集成。Pathfinder 实现可以通过 pymc-extras 访问,可以通过以下方式安装

pip install git+https://github.com/pymc-devs/pymc-extras

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pymc_extras as pmx

print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v5.20.1

首先,定义您的 PyMC 模型。这里,我们使用 8-schools 模型。

# Data of the Eight Schools Model
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    z = pm.Normal("z", mu=0, sigma=1, shape=J)
    theta = pm.Deterministic("theta", mu + tau * z)
    obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y)

接下来,我们调用 pmx.fit() 并传入我们要使用的算法。

rng = np.random.default_rng(123)
with model:
    idata_ref = pm.sample(target_accept=0.9, random_seed=rng)
    idata_path = pmx.fit(
        method="pathfinder",
        jitter=12,
        num_draws=1000,
        random_seed=123,
    )
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, tau, z]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.

Pathfinder Results                          
                                            
  No. model parameters     10               
                                            
  Configuration:                            
  num_draws_per_path       1000             
  history size (maxcor)    7                
  max iterations           1000             
  ftol                     1.00e-05         
  gtol                     1.00e-08         
  max line search          1000             
  jitter                   12               
  epsilon                  1.00e-08         
  ELBO draws               10               
                                            
  LBFGS Status:                             
  CONVERGED                4                
  L-BFGS iterations        mean 22 ± std 6  
                                            
  Path Status:                              
  SUCCESS                  4                
  ELBO argmax              mean 8 ± std 9   
                                            
  Importance Sampling:                      
  Method                   psis             
  Pareto k                 0.75             
                                            
  Timing (seconds):                         
  Compile                  4.53             
  Compute                  0.09             
  Total                    4.62             

就像 pymc.sample() 一样,这会返回一个包含后验样本的 idata。请注意,由于这些样本不是来自 MCMC 链,因此无法以常规方式评估收敛性。

az.plot_forest(
    [idata_ref, idata_path],
    var_names=["~z"],
    model_names=["ref", "path"],
    combined=True,
);
../_images/60ecfb8b14bd4a05cb6efde36c7ca4d7d9e2ab499a24ea97f995262ce81573ea.png

参考文献#

[1]

Lu Zhang, Bob Carpenter, Andrew Gelman, and Aki Vehtari. Pathfinder: parallel quasi-newton variational inference. arXiv preprint arXiv:2108.03782, 2021.

作者#

  • 作者:Thomas Wiecki,于 2022 年 10 月 11 日 (pymc-examples#429)

  • 重新执行笔记本:Reshama Shaikh,于 2023 年 2 月 5 日

  • 错误修复:Chris Fonnesbeck,于 2024 年 7 月 17 日

  • 更新到 PyMC 实现:Michael Cao,于 2025 年 2 月 13 日

  • 文本更新:Chris Fonnesbeck,于 2025 年 2 月 19 日

水印#

%load_ext watermark
%watermark -n -u -v -iv -w -p xarray
Last updated: Wed Feb 19 2025

Python implementation: CPython
Python version       : 3.12.9
IPython version      : 8.32.0

xarray: 2025.1.2

arviz      : 0.19.0
numpy      : 1.26.4
matplotlib : 3.10.0
pymc_extras: 0.2.3
pymc       : 5.20.1

Watermark: 2.5.0

许可声明#

本示例库中的所有笔记本均根据 MIT 许可证 提供,该许可证允许修改和再分发用于任何用途,前提是保留版权和许可声明。

引用 PyMC 示例#

要引用此笔记本,请使用 Zenodo 为 pymc-examples 仓库提供的 DOI。

重要提示

许多笔记本改编自其他来源:博客、书籍……在这种情况下,您也应该引用原始来源。

另请记住引用您的代码使用的相关库。

这是一个 bibtex 格式的引用模板

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

渲染后可能如下所示