arviz_stats.loo_moment_match
- arviz_stats.loo_moment_match(data, loo_orig, log_prob_upars_fn, log_lik_i_upars_fn, upars=None, var_name=None, reff=None, max_iters=30, k_threshold=None, split=True, cov=False, pointwise=None)
Compute moment matching for problematic observations in PSIS-LOO-CV.
Adjusts the results of a previously computed Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) object by applying a moment matching algorithm to observations with high Pareto k diagnostic values. The moment matching algorithm iteratively adjusts the posterior draws in the unconstrained parameter space to better approximate the leave-one-out posterior.
The moment matching algorithm is described in [1] and the PSIS-LOO-CV method is described in [2] and [3].
- Parameters:
- data: xarray.DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood groups.
- loo_orig: ELPDData
An existing ELPDData object from a previous loo result. Must contain pointwise Pareto k values (pointwise=True must have been used).
- log_prob_upars_fn: Callable[[xarray.DataArray], xarray.DataArray]
A function that takes the unconstrained parameter draws and returns a DataArray containing the log probability density of the full posterior distribution evaluated at each unconstrained parameter draw. The returned DataArray must have dimensions chain, draw.
- log_lik_i_upars_fn: Callable[[xarray.DataArray, int], xarray.DataArray]
A function that takes the unconstrained parameter draws and the integer index i of the left-out observation. It should return a DataArray containing the log-likelihood of the left-out observation i evaluated at each unconstrained parameter draw. The returned DataArray must have dimensions chain, draw.
- upars: xarray.DataArray, optional
Posterior draws transformed to the unconstrained parameter space. Must have chain and draw dimensions, plus one additional dimension containing all parameters. Parameter names can be provided as coordinate values on this dimension. If not provided, the function attempts to use the unconstrained_posterior group from the input data, if available.
- var_name: str, optional
The name of the variable in the log_likelihood group storing the pointwise log likelihood data to use for the loo computation.
- reff: float, optional
Relative MCMC efficiency, ess / n, i.e. the number of effective samples divided by the number of actual samples. Computed from the trace by default.
- max_iters: int, default 30
Maximum number of moment matching iterations for each problematic observation.
- k_threshold: float, optional
Threshold value for Pareto k values above which moment matching is applied. Defaults to \(\min(1 - 1/\log_{10}(S), 0.7)\), where S is the number of samples.
- split: bool, default True
If True, only transform half of the draws and use multiple importance sampling to combine them with untransformed draws.
- cov: bool, default False
If True, match the covariance structure during the transformation, in addition to the mean and marginal variances. Ignored if split=False.
- pointwise: bool, optional
If True, the pointwise predictive accuracy will be returned. Defaults to rcParams["stats.ic_pointwise"]. Moment matching always requires pointwise data from loo_orig. This argument controls whether the returned object includes pointwise data.
- Returns:
ELPDData
Object with the following attributes:
elpd: expected log pointwise predictive density
se: standard error of the elpd
p: effective number of parameters
n_samples: number of samples
n_data_points: number of data points
warning: True if the estimated shape parameter of the Pareto distribution is greater than good_k.
elpd_i: DataArray with the pointwise predictive accuracy, only if pointwise=True
pareto_k: array of Pareto shape values, only if pointwise=True
good_k: For a sample size S, the threshold is computed as min(1 - 1/log10(S), 0.7)
approx_posterior: True if approximate posterior was used.
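As a quick illustration of the good_k formula above, the following minimal sketch evaluates min(1 - 1/log10(S), 0.7) for a few sample sizes. The helper name `good_k_threshold` is hypothetical and is not part of arviz_stats.

```python
import math


def good_k_threshold(n_samples: int) -> float:
    """Hypothetical helper: min(1 - 1/log10(S), 0.7) for S posterior samples."""
    if n_samples <= 1:
        raise ValueError("n_samples must be greater than 1")
    return min(1.0 - 1.0 / math.log10(n_samples), 0.7)


# The threshold grows with the sample size and is capped at 0.7.
for s in (100, 2000, 100_000):
    print(s, round(good_k_threshold(s), 3))
```

For 2000 samples the formula gives roughly 0.697, slightly below the 0.7 cap, and the cap takes effect for larger sample sizes.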
Notes
The moment matching algorithm considers three affine transformations of the posterior draws. For a specific draw \(\theta^{(s)}\), a generic affine transformation includes a square matrix \(\mathbf{A}\) representing a linear map and a vector \(\mathbf{b}\) representing a translation such that
\[T : \theta^{(s)} \mapsto \mathbf{A}\theta^{(s)} + \mathbf{b} =: \theta^{*{(s)}}.\]
The first transformation, \(T_1\), is a translation that matches the mean of the sample to its importance weighted mean, given by
\[\mathbf{\theta^{*{(s)}}} = T_1(\mathbf{\theta^{(s)}}) = \mathbf{\theta^{(s)}} - \bar{\theta} + \bar{\theta}_w,\]
where \(\bar{\theta}\) is the mean of the sample and \(\bar{\theta}_w\) is the importance weighted mean of the sample. The second transformation, \(T_2\), is a scaling that matches the marginal variances in addition to the means, given by
\[\mathbf{\theta^{*{(s)}}} = T_2(\mathbf{\theta^{(s)}}) = \mathbf{v}^{1/2}_w \circ \mathbf{v}^{-1/2} \circ (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,\]
where \(\mathbf{v}\) and \(\mathbf{v}_w\) are the sample and weighted variances, and \(\circ\) denotes the pointwise product of the elements of two vectors. The third transformation, \(T_3\), is a covariance transformation that matches the covariance matrix of the sample to its importance weighted covariance matrix, given by
\[\mathbf{\theta^{*{(s)}}} = T_3(\mathbf{\theta^{(s)}}) = \mathbf{L}_w \mathbf{L}^{-1} (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,\]
where \(\mathbf{L}\) and \(\mathbf{L}_w\) are the Cholesky factors of the covariance matrix and the weighted covariance matrix, respectively, i.e.,
\[\mathbf{LL}^T = \mathbf{\Sigma} = \frac{1}{S} \sum_{s=1}^S (\mathbf{\theta^{(s)}} - \bar{\theta}) (\mathbf{\theta^{(s)}} - \bar{\theta})^T\]
and
\[\mathbf{L}_w \mathbf{L}_w^T = \mathbf{\Sigma}_w = \frac{\frac{1}{S} \sum_{s=1}^S w^{(s)} (\mathbf{\theta^{(s)}} - \bar{\theta}_w) (\mathbf{\theta^{(s)}} - \bar{\theta}_w)^T}{\sum_{s=1}^S w^{(s)}}.\]
We iterate on \(T_1\) repeatedly and move on to \(T_2\) and \(T_3\) only if \(T_1\) fails to yield a Pareto k statistic below the threshold.
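The three transformations can be sketched in plain NumPy as follows. This is an illustrative sketch, not the arviz_stats implementation: the function names are hypothetical, the draws are assumed to be flattened to shape (S, D), the importance weights are self-normalized to sum to one (an assumption of this sketch), and the reweighting and Pareto k checks performed after each transformation are omitted.

```python
import numpy as np


def weighted_moments(theta, w):
    """Weighted mean, marginal variances and covariance of draws.

    theta has shape (S, D); w holds S non-negative importance weights.
    """
    w = w / w.sum()  # self-normalize the weights (assumption of this sketch)
    mean_w = w @ theta
    centered = theta - mean_w
    var_w = w @ centered**2
    cov_w = (centered * w[:, None]).T @ centered
    return mean_w, var_w, cov_w


def shift(theta, w):
    """T1: translate the draws so their mean equals the weighted mean."""
    mean_w, _, _ = weighted_moments(theta, w)
    return theta - theta.mean(axis=0) + mean_w


def shift_and_scale(theta, w):
    """T2: match the marginal variances in addition to the mean."""
    mean_w, var_w, _ = weighted_moments(theta, w)
    scale = np.sqrt(var_w / theta.var(axis=0))  # v_w^{1/2} * v^{-1/2}
    return scale * (theta - theta.mean(axis=0)) + mean_w


def shift_and_cov(theta, w):
    """T3: match the full covariance matrix through Cholesky factors."""
    mean_w, _, cov_w = weighted_moments(theta, w)
    centered = theta - theta.mean(axis=0)
    cov = centered.T @ centered / len(theta)
    chol = np.linalg.cholesky(cov)
    chol_w = np.linalg.cholesky(cov_w)
    # Row-wise form of L_w L^{-1} (theta - mean) + weighted mean
    return (chol_w @ np.linalg.solve(chol, centered.T)).T + mean_w


# Made-up draws and weights, purely for illustration
rng = np.random.default_rng(0)
theta = rng.normal(size=(500, 3))
w = np.exp(0.5 * theta[:, 0])

mean_w, var_w, cov_w = weighted_moments(theta, w)
print(np.allclose(shift(theta, w).mean(axis=0), mean_w))
print(np.allclose(shift_and_scale(theta, w).var(axis=0), var_w))
print(np.allclose(np.cov(shift_and_cov(theta, w).T, bias=True), cov_w))
```

Each transformation matches one more moment than the previous one, which is why the cheaper translation is tried first.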
References
[1] Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. (2021). Implicitly Adaptive Importance Sampling. Statistics and Computing, 31(2). https://doi.org/10.1007/s11222-020-09982-2 arXiv preprint https://arxiv.org/abs/1906.08850
[2] Vehtari et al. (2017). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5). https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544
[3] Vehtari et al. (2024). Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72). https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646
Examples
Moment matching can improve PSIS-LOO-CV estimates for observations with high Pareto k values without having to refit the model for each problematic observation. We will use the non-centered eight schools data, which has one problematic observation:
In [1]: import arviz_base as az
   ...: from arviz_stats import loo, loo_moment_match
   ...: import numpy as np
   ...: import xarray as xr
   ...: from scipy import stats
   ...:
   ...: idata = az.load_arviz_data("non_centered_eight")
   ...: loo_orig = loo(idata, pointwise=True, var_name="obs")
   ...: loo_orig
   ...:
Out[1]:
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   -30.72     1.33
p_loo        0.90        -

There has been a warning during the calculation. Please check the results.
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        7   87.5%
   (0.70, 1]   (bad)         1   12.5%
    (1, Inf)   (very bad)    0    0.0%
For moment matching, we need the unconstrained parameters and two functions for the log probability and pointwise log-likelihood computations:
In [2]: posterior = idata.posterior
   ...: theta_t = posterior.theta_t.values
   ...: mu = posterior.mu.values[:, :, np.newaxis]
   ...: log_tau = np.log(posterior.tau.values)[:, :, np.newaxis]
   ...:
   ...: # Stack all unconstrained parameters along one extra dimension
   ...: upars = np.concatenate([theta_t, mu, log_tau], axis=2)
   ...: param_names = [f"theta_t_{i}" for i in range(8)] + ["mu", "log_tau"]
   ...:
   ...: upars = xr.DataArray(
   ...:     upars,
   ...:     dims=["chain", "draw", "upars_dim"],
   ...:     coords={
   ...:         "chain": posterior.chain,
   ...:         "draw": posterior.draw,
   ...:         "upars_dim": param_names
   ...:     }
   ...: )
   ...:
   ...: def log_prob_upars(upars):
   ...:     theta_tilde = upars.sel(upars_dim=[f"theta_t_{i}" for i in range(8)])
   ...:     mu = upars.sel(upars_dim="mu")
   ...:     log_tau = upars.sel(upars_dim="log_tau")
   ...:     tau = np.exp(log_tau)
   ...:
   ...:     log_prob = stats.norm(0, 5).logpdf(mu.values)
   ...:     log_prob += stats.halfcauchy(0, 5).logpdf(tau.values)
   ...:     # Log-Jacobian of the tau = exp(log_tau) transformation
   ...:     log_prob += log_tau.values
   ...:     log_prob += stats.norm(0, 1).logpdf(theta_tilde.values).sum(axis=-1)
   ...:
   ...:     return xr.DataArray(
   ...:         log_prob,
   ...:         dims=["chain", "draw"],
   ...:         coords={"chain": upars.chain, "draw": upars.draw}
   ...:     )
   ...:
   ...: def log_lik_i_upars(upars, i):
   ...:     # Known standard errors of the eight schools
   ...:     sigmas = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
   ...:     theta_tilde_i = upars.sel(upars_dim=f"theta_t_{i}")
   ...:     mu = upars.sel(upars_dim="mu")
   ...:     tau = np.exp(upars.sel(upars_dim="log_tau"))
   ...:     # Non-centered parameterization: theta_i = mu + tau * theta_t_i
   ...:     theta_i = mu + tau * theta_tilde_i
   ...:     y_i = idata.observed_data.obs.values[i]
   ...:     log_lik = stats.norm(theta_i.values, sigmas[i]).logpdf(y_i)
   ...:
   ...:     return xr.DataArray(
   ...:         log_lik,
   ...:         dims=["chain", "draw"],
   ...:         coords={"chain": upars.chain, "draw": upars.draw}
   ...:     )
   ...:
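The `log_prob += log_tau.values` term in log_prob_upars is the log-Jacobian of the change of variables from tau to log_tau: the density on the unconstrained scale is p(tau) multiplied by tau. The following sketch checks this numerically for the half-Cauchy(0, 5) prior used above. The helper `log_halfcauchy` is hypothetical and written by hand so that the block depends only on NumPy.

```python
import numpy as np


def log_halfcauchy(tau, scale=5.0):
    """Log density of a half-Cauchy(0, scale) distribution for tau > 0."""
    return np.log(2.0) - np.log(np.pi * scale) - np.log1p((tau / scale) ** 2)


# Evenly spaced grid on the unconstrained (log_tau) scale
log_tau = np.linspace(-30.0, 30.0, 20001)
step = log_tau[1] - log_tau[0]

# With the log-Jacobian term "+ log_tau" the density integrates to one
log_density = log_halfcauchy(np.exp(log_tau)) + log_tau
with_jacobian = np.sum(np.exp(log_density)) * step

# Without it, the function of log_tau is not a normalized density
without_jacobian = np.sum(np.exp(log_halfcauchy(np.exp(log_tau)))) * step

print(with_jacobian, without_jacobian)
```

Only the version with the Jacobian term integrates to approximately one, which is why a log-density function on the unconstrained scale needs that term for every transformed parameter.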
We can now apply moment matching using the split transformation, matching the means and marginal variances only (cov=False). All Pareto \(k\) values are now below the threshold, and the ELPD estimate changes slightly, from -30.72 to -30.60:
In [3]: loo_mm = loo_moment_match(
   ...:     idata,
   ...:     loo_orig,
   ...:     upars=upars,
   ...:     log_prob_upars_fn=log_prob_upars,
   ...:     log_lik_i_upars_fn=log_lik_i_upars,
   ...:     var_name="obs",
   ...:     k_threshold=0.7,
   ...:     split=True,
   ...:     cov=False,
   ...: )
   ...: loo_mm
   ...:
Out[3]:
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   -30.60     1.37
p_loo        0.79        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        8  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%