实现 RandomVariable 分布#
本指南概述了如何为 PyMC 实现分布。它专为希望向库中添加新分布的开发人员而设计。用户不会意识到所有这些复杂性,而应该使用诸如 ~pymc.CustomDist
等辅助方法。
PyMC Distribution
构建于 PyTensor 的 RandomVariable
之上,并实现了 logp
、logcdf
、icdf
和 support_point
方法,以及其他初始化和验证助手。最值得注意的是 shape/dims/observed
kwargs、替代参数化和默认 transform
。
以下是实现新分布所需步骤的摘要清单。每个部分将在下面展开
创建新的
RandomVariable
Op
实现相应的
Distribution
类为新的
RandomVariable
添加测试为
logp
/logcdf
/icdf
和support_point
方法添加测试为新的
Distribution
编写文档。
本指南不试图解释 Distributions
当前实现背后的原理,提供的细节仅在有助于实现新的“标准”分布的范围内。
1. 创建新的 RandomVariable
Op
#
RandomVariable
负责实现随机抽样方法。RandomVariable
还负责参数广播和形状推断。
在创建新的 RandomVariable
之前,请确保 NumPy 库
中尚未提供它。如果已提供,则应首先将其添加到 PyTensor 库,然后再导入到 PyMC 库中。
此外,可能并非总是需要实现新的 RandomVariable
。例如,如果新的 Distribution
只是现有 Distribution
的特殊参数化。 OrderedLogistic
和 OrderedProbit
就是这种情况,它们只是 Categorical
分布的特殊参数化。
以下代码片段说明了如何创建新的 RandomVariable
from pytensor.tensor.var import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from typing import List, Tuple
# Create your own `RandomVariable`...
class BlahRV(RandomVariable):
name: str = "blah"
# Provide a numpy-style signature for this RV, which indicates
# the number and core dimensionality of each input and output.
signature: "(),()->()"
# The NumPy/PyTensor dtype for this RV (e.g. `"int32"`, `"int64"`).
# The standard in the library is `"int64"` for discrete variables
# and `"floatX"` for continuous variables
dtype: str = "floatX"
# A pretty text and LaTeX representation for the RV
_print_name: Tuple[str, str] = ("blah", "\\operatorname{blah}")
# If you want to add a custom signature and default values for the
# parameters, do it like this. Otherwise this can be left out.
def __call__(self, loc=0.0, scale=1.0, **kwargs) -> TensorVariable:
return super().__call__(loc, scale, **kwargs)
# This is the Python code that produces samples. Its signature will always
# start with a NumPy `RandomState` object, then the distribution
# parameters, and, finally, the size.
@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
loc: np.ndarray,
scale: np.ndarray,
size: Tuple[int, ...],
) -> np.ndarray:
return scipy.stats.blah.rvs(loc, scale, random_state=rng, size=size)
# Create the actual `RandomVariable` `Op`...
blah = BlahRV()
一些需要牢记的重要事项
rng_fn
方法内部的所有内容都是纯 Python 代码(输入也是如此),并且不应使用其他PyTensor
符号运算。随机方法应使用 NumPyRandomGenerator
rng
,以使样本可重现。非默认
RandomVariable
维度将通过size
kwarg 进入rng_fn
。rng_fn
将必须考虑这一点以获得正确的输出。size
是 NumPy 和 SciPy 使用的规范,并且对于单变量分布,其工作方式类似于 PyMCshape
,但对于多变量分布则不同。对于多变量分布,size
排除支持维度,而结果TensorVariable
或ndarray
的shape
包括支持维度。有关更多上下文,请查看 维度笔记本。PyTensor
可以自动推断单变量RandomVariable
的输出形状。对于多变量分布,必须在新RandomVariable
类中实现方法_supp_shape_from_params
。此方法返回给定 RV 参数的支持维度。在某些情况下,可以从其参数之一的形状中导出,在这种情况下,助手pytensor.tensor.random.utils.supp_shape_from_ref_param_shape()
可以像在DirichletMultinomialRV
中一样使用。在其他情况下,参数值(而不是它们的形状)可能会决定分布的支持形状,就像在~pymc.distributions.multivarite._LKJCholeskyCovRV
中发生的那样。在更简单的情况下,它们可能是常数。可以在新
rng_fn
内部使用其他 PyTensor 和 PyMCRandomVariables
的rng_fn
类方法
。例如,如果要实现负 HalfNormalRandomVariable
,则rng_fn
可以简单地返回- halfnormal.rng_fn(rng, scale, size)
。
注意:除了 size
之外,PyMC API 还提供了 shape
、dims
和 observed
作为定义分布维度的替代方法,但这由 Distribution
负责处理,并且不应需要任何额外的更改。
为了快速测试您的新 RandomVariable
Op
是否正常工作,您可以调用带有必要参数的 Op
,然后在返回的对象上调用 draw
# blah = pytensor.tensor.random.uniform in this example
# multiple calls with the same seed should return the same values
pm.draw(blah([0, 0], [1, 2], size=(10, 2)), random_seed=1)
# array([[0.83674527, 0.76593773],
# [0.00958496, 1.85742402],
# [0.74001876, 0.6515534 ],
# [0.95134629, 1.23564938],
# [0.41460156, 0.33241175],
# [0.66707807, 1.62134924],
# [0.20748312, 0.45307477],
# [0.65506507, 0.47713784],
# [0.61284429, 0.49720329],
# [0.69325978, 0.96272673]])
2. 从 PyMC 基础 Distribution
类继承#
在实现新的 RandomVariable
Op
之后,就可以在新 PyMC Distribution
中使用它了。PyMC 以非常 函数式 的方式工作,并且 distribution
类的主要作用是添加 PyMC API 功能并将相关方法组织在一起。在实践中,它们负责
将 (调度)
rv_op
类与相应的support_point
、logp
、logcdf
和icdf
方法链接起来。定义标准转换(对于连续分布),将有界变量域(例如,正线)转换为无界域(即,实线),许多采样器更喜欢后者。
验证分布的参数化并将非符号输入(即,数字文字或 NumPy 数组)转换为符号变量。
将多个替代参数化转换为以
RandomVariable
定义的标准参数化。
以下是示例的延续
import pytensor.tensor as pt
from pymc.distributions.continuous import PositiveContinuous
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.shape_utils import rv_size_is_none
# Subclassing `PositiveContinuous` will dispatch a default `log` transformation
class Blah(PositiveContinuous):
# This will be used by the metaclass `DistributionMeta` to dispatch the
# class `logp` and `logcdf` methods to the `blah` `Op` defined in the last line of the code above.
rv_op = blah
# dist() is responsible for returning an instance of the rv_op.
# We pass the standard parametrizations to super().dist
@classmethod
def dist(cls, param1, param2=None, alt_param2=None, **kwargs):
param1 = pt.as_tensor_variable(param1)
if param2 is not None and alt_param2 is not None:
raise ValueError("Only one of param2 and alt_param2 is allowed.")
if alt_param2 is not None:
param2 = 1 / alt_param2
param2 = pt.as_tensor_variable(param2)
# The first value-only argument should be a list of the parameters that
# the rv_op needs in order to be instantiated
return super().dist([param1, param2], **kwargs)
# support_point returns a symbolic expression for the stable point from which to start sampling
# the variable, given the implicit `rv`, `size` and `param1` ... `paramN`.
# This is typically a "representative" point such as the the mean or mode.
def support_point(rv, size, param1, param2):
support_point, _ = pt.broadcast_arrays(param1, param2)
if not rv_size_is_none(size):
support_point = pt.full(size, support_point)
return support_point
# Logp returns a symbolic expression for the elementwise log-pdf or log-pmf evaluation
# of the variable given the `value` of the variable and the parameters `param1` ... `paramN`.
def logp(value, param1, param2):
logp_expression = value * (param1 + pt.log(param2))
# A switch is often used to enforce the distribution support domain
bounded_logp_expression = pt.switch(
pt.gt(value >= 0),
logp_expression,
-np.inf,
)
# We use `check_parameters` for parameter validation. After the default expression,
# multiple comma-separated symbolic conditions can be added.
# Whenever a bound is invalidated, the returned expression raises an error
# with the message defined in the optional `msg` keyword argument.
return check_parameters(
bounded_logp_expression,
param2 >= 0,
msg="param2 >= 0",
)
# logcdf works the same way as logp. For bounded variables, it is expected to return
# `-inf` for values below the domain start and `0` for values above the domain end.
def logcdf(value, param1, param2):
...
def icdf(value, param1, param2):
...
一些注意事项
分布至少应从
Discrete
或Continuous
继承。对于后者,存在更具体的子类:PositiveContinuous
、UnitContinuous
、BoundedContinuous
、CircularContinuous
、SimplexContinuous
,它们为变量指定默认转换。 如果您需要指定一次性自定义转换,您还可以创建一个_default_transform
调度函数,就像为LKJCholeskyCov
所做的那样。如果分布没有相应的
rng_fn
实现,则仍应创建RandomVariable
以引发NotImplementedError
。例如,Flat
就是这种情况。在这种情况下,将需要提供support_point
方法,因为如果没有rng_fn
,PyMC 将无法回退到随机抽取以用作 MCMC 的初始点。如上所述,PyMC 以非常 函数式 的方式工作,并且
logp
、logcdf
、icdf
和support_point
方法中需要的所有信息都应通过RandomVariable
输入“携带”。您可以传递严格来说rng_fn
方法不需要但在这些方法中使用的数值参数。只需记住这是否会影响RandomVariable
的正确形状推断行为。logcdf
和icdf
方法不是必需的,但这是一个不错的加分项!目前,
support_point
方法中仅支持一个矩,并且可能是“高阶”矩最有用(即mean
>median
>mode
)…… 如果您正在处理离散分布,则可能需要截断矩。support_point
应返回随机变量的有效点(即,在该点评估时,它始终具有非零概率)创建
support_point
方法时,请注意size != None
并根据不一定用于计算矩的参数正确广播。例如,pm.Normal.dist(mu=0, sigma=np.arange(1, 6))
中的sigma
与矩无关,但可能仍然会告知形状。在这种情况下,support_point
应返回[mu, mu, mu, mu, mu]
。
为了快速检查事物是否正常工作,您可以尝试以下操作
import pymc as pm
from pymc.distributions.distribution import support_point
# pm.blah = pm.Normal in this example
blah = pm.blah.dist(mu=0, sigma=1)
# Test that the returned blah_op is still working fine
pm.draw(blah, random_seed=1)
# array(-1.01397228)
# Test the support_point method
support_point(blah).eval()
# array(0.)
# Test the logp method
pm.logp(blah, [-0.5, 1.5]).eval()
# array([-1.04393853, -2.04393853])
# Test the logcdf method
pm.logcdf(blah, [-0.5, 1.5]).eval()
# array([-1.17591177, -0.06914345])
3. 为新的 RandomVariable
添加测试#
新 RandomVariables
的测试主要位于 tests/distributions/test_*.py
中。大多数测试可以通过默认 BaseTestDistributionRandom
类容纳,该类为检查提供默认测试
预期输入通过
dist
classmethod
,通过check_pymc_params_match_rv_op
传递到rv_op
通过
check_pymc_draws_match_reference
返回预期的(精确)抽样形状变量推断正确,通过
check_rv_size
from pymc.testing import BaseTestDistributionRandom, seeded_scipy_distribution_builder
class TestBlah(BaseTestDistributionRandom):
pymc_dist = pm.Blah
# Parameters with which to test the blah pymc Distribution
pymc_dist_params = {"param1": 0.25, "param2": 2.0}
# Parameters that are expected to have passed as inputs to the RandomVariable op
expected_rv_op_params = {"param1": 0.25, "param2": 2.0}
# If the new `RandomVariable` is simply calling a `numpy`/`scipy` method,
# we can make use of `seeded_[scipy|numpy]_distribution_builder` which
# will prepare a seeded reference distribution for us.
reference_dist_params = {"mu": 0.25, "loc": 2.0}
reference_dist = seeded_scipy_distribution_builder("blah")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]
应为分布的每个可选参数化添加其他测试。在这种情况下,包含测试 check_pymc_params_match_rv_op
就足够了,因为只有这一点不同。
确保测试的替代参数值会导致关联的默认参数的值不同。例如,如果它只是倒数,则使用 1.0
进行测试不是很informative,因为转换也会返回 1.0
,并且我们不能(像)确定它是否正常工作。
class TestBlahAltParam2(BaseTestDistributionRandom):
pymc_dist = pm.Blah
# param2 is equivalent to 1 / alt_param2
pymc_dist_params = {"param1": 0.25, "alt_param2": 4.0}
expected_rv_op_params = {"param1": 0.25, "param2": 2.0}
tests_to_run = ["check_pymc_params_match_rv_op"]
也可以将自定义测试添加到类中,就像为 TestFlat
所做的那样。
关于 check_rv_size
测试的注意事项:#
可以通过添加可选的类属性 sizes_to_check
和 sizes_expected
为 check_rv_size
测试定义自定义输入大小(和预期输出形状)
sizes_to_check = [None, (1), (2, 3)]
sizes_expected = [(3,), (1, 3), (2, 3, 3)]
tests_to_run = ["check_rv_size"]
这通常是多变量分布所需要的。您可以在 TestDirichlet
中看到一个示例。
关于 check_pymcs_draws_match_reference
测试的注意事项#
check_pymcs_draws_match_reference
是一个非常简单的测试,用于测试给定相同输入和随机种子的 RandomVariable
的抽样与完全相同的 python 函数的抽样是否相等。检查了少量 (size=15
)。这不应该是随机数生成器正确性的测试。后一种测试(如果需要)可以在 pymc_random
和 pymc_random_discrete
方法的帮助下执行,这将对 RandomVariable.rng_fn
和参考 Python 函数之间执行昂贵的统计比较。仅当存在良好的独立生成器参考时,这种测试才有意义(即,不仅仅是 rng_fn
内部完成的 NumPy / SciPy 调用的相同组合)。
最后,当您的 rng_fn
所做的事情不仅仅是调用 NumPy 或 SciPy 方法时,您将需要设置一个等效的种子函数,以便与精确抽样进行比较(而不是依赖于 seeded_[scipy|numpy]_distribution_builder
)。您可以在 TestWeibull
中找到一个示例,其 rng_fn
返回 beta * np.random.weibull(alpha, size=size)
。
4. 为 logp
/ logcdf
/ icdf
方法添加测试#
logp
、logcdf
和 icdf
的测试主要使用 ~testing
中实现的助手 check_logp
、check_logcdf
、check_icdf
和 check_selfconsistency_discrete_logcdf
from pymc.testing import Domain, check_logp, check_logcdf, select_by_precision
R = Domain([-np.inf, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.inf])
Rplus = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 100, np.inf])
def test_blah():
check_logp(
pymc_dist=pm.Blah,
# Domain of the distribution values
domain=R,
# Domains of the distribution parameters
paramdomains={"mu": R, "sigma": Rplus},
# Reference scipy (or other) logp function
scipy_logp=lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
# Number of decimal points expected to match between the pymc and reference functions
decimal=select_by_precision(float64=6, float32=3),
# Maximum number of combinations of domain * paramdomains to test
n_samples=100,
)
check_logcdf(
pymc_dist=pm.Blah,
domain=R,
paramdomains={"mu": R, "sigma": Rplus},
scipy_logcdf=lambda value, mu, sigma: sp.norm.logcdf(value, mu, sigma),
decimal=select_by_precision(float64=6, float32=1),
n_samples=-1,
)
这些方法将在 domain
和 paramdomains
值的组合上执行网格评估,并检查 PyMC 方法和参考函数是否匹配。有一些值得牢记的细节
默认情况下,不比较
Domain
的第一个和最后一个值(边缘)(它们用于其他用途)。如果测试Domain
的边缘很重要,则可以重复边缘值。这是由Bool
完成的:Bool = Domain([0, 0, 1, 1], "int64")
有一些默认域(例如
R
和Rplus
),您可以使用它们来测试新分布,但是如果出于充分的理由(例如,当默认值导致太多极端的、不太可能的组合,而这些组合对于实现的正确性没有太多信息时),在测试函数内部创建自己的域也是完全可以的。默认情况下,测试 100 个
param
xparamdomain
组合的随机子集,以控制测试运行时长。测试闪亮的新分布时,您可以暂时设置n_samples=-1
以强制测试所有组合。这很重要,可以避免您的PR
在将来运行中导致意外失败,只要随机测试了一些错误的参数组合。在 GitHub 上,某些测试在
pytensor.config.floatX
标志"float64"
和"float32"
下运行两次。但是,参考 Python 函数将在纯“float64”环境中运行,这意味着参考和 PyMC 结果可能会有很大差异(例如,对于极端参数,下溢到-np.inf
)。因此,您应确保在两种状态下都在本地进行测试。一种快速而简便的方法是在文件顶部,紧接在import pytensor
之后,临时添加pytensor.config.floatX = "float32"
。记住也要设置n_samples=-1
以测试所有组合。测试输出将显示哪些确切的参数值导致失败。如果您确信您的实现是正确的,您可以选择使用select_by_precision
调整十进制精度,或调整测试的Domain
值。在极端情况下,您可以使用条件xfail
标记测试(如果只有子方法之一失败,则应将其分开,以便xfail
尽可能窄)
def test_blah_logp(self):
...
@pytest.mark.xfail(
condition=(pytensor.config.floatX == "float32"),
reason="Fails on float32 due to numerical issues",
)
def test_blah_logcdf(self):
...
5. 为 support_point
方法添加测试#
support_point
的测试使用函数 assert_support_point_is_expected
,该函数检查是否
矩返回
expected
值矩具有预期的大小和形状
矩具有有限的 logp
import pytest
from pymc.distributions import Blah
from pymc.testing import assert_support_point_is_expected
@pytest.mark.parametrize(
"param1, param2, size, expected",
[
(0, 1, None, 0),
(0, np.ones(5), None, np.zeros(5)),
(np.arange(5), 1, None, np.arange(5)),
(np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
],
)
def test_blah_support_point(param1, param2, size, expected):
with Model() as model:
Blah("x", param1=param1, param2=param2, size=size)
assert_support_point_is_expected(model, expected)
以下是一些值得牢记的细节
在必须手动相互广播参数的情况下,添加测试条件非常重要,如果您不这样做,测试条件将失败。执行此操作的直接方法是使使用的参数为标量,未使用的参数为向量(一次一个)和大小
None
。换句话说,请确保测试大小和广播的不同组合以涵盖这些情况。
6. 为新的 Distribution
编写文档#
新的分布应具有丰富的文档字符串,格式与先前实现的分布的格式相同。它通常看起来像这样
r"""Univariate blah distribution.
The pdf of this distribution is
.. math::
f(x \mid \param1, \param2) = \exp{x * (param1 + \log{param2})}
.. plot::
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import arviz as az
x = np.linspace(-5, 5, 1000)
params1 = [0., 0., 0., -2.]
params2 = [0.4, 1., 2., 0.4]
for param1, param2 in zip(params1, params2):
pdf = st.blah.pdf(x, param1, param2)
plt.plot(x, pdf, label=r'$\param1$ = {}, $\param2$ = {}'.format(param1, param2))
plt.xlabel('x', fontsize=12)
plt.ylabel('f(x)', fontsize=12)
plt.legend(loc=1)
plt.show()
======== ==========================================
Support :math:`x \in [0, \infty)`
======== ==========================================
Blah distribution can be parameterized either in terms of param2 or
alt_param2. The link between the two parametrizations is
given by
.. math::
\param2 = \dfrac{1}{\alt_param2}
Parameters
----------
param1: float
Interpretation of param1.
param2: float
Interpretation of param2 (param2 > 0).
alt_param2: float
Interpretation of alt_param2 (alt_param2 > 0) (alternative to param2).
Examples
--------
.. code-block:: python
with pm.Model():
x = pm.Blah('x', param1=0, param2=10)
"""
新的分布应在 docs
模块(例如,pymc/docs/api/distributions.continuous.rst
)中各自的 API 页面中引用。如果合适,应将新的笔记本示例添加到 pymc-examples,说明如何使用此分布以及它与其他用户更熟悉的分布的关系(和/或差异)。