arviz_stats.loo_moment_match

arviz_stats.loo_moment_match(data, loo_orig, log_prob_upars_fn, log_lik_i_upars_fn, upars=None, var_name=None, reff=None, max_iters=30, k_threshold=None, split=True, cov=False, pointwise=None)

Compute moment matching for problematic observations in PSIS-LOO-CV.

Adjusts the results of a previously computed Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) object by applying a moment matching algorithm to observations with high Pareto k diagnostic values. The moment matching algorithm iteratively adjusts the posterior draws in the unconstrained parameter space to better approximate the leave-one-out posterior.

The moment matching algorithm is described in [1] and the PSIS-LOO-CV method is described in [2] and [3].

Parameters:
data : xarray.DataTree or InferenceData

Input data. It should contain the posterior and the log_likelihood groups.

loo_orig : ELPDData

An existing ELPDData object from a previous loo result. Must contain pointwise Pareto k values (pointwise=True must have been used).

log_prob_upars_fn : Callable[[xarray.DataArray], xarray.DataArray]

A function that takes the unconstrained parameter draws and returns a DataArray containing the log probability density of the full posterior distribution evaluated at each unconstrained parameter draw. The returned DataArray must have dimensions chain, draw.

log_lik_i_upars_fn : Callable[[xarray.DataArray, int], xarray.DataArray]

A function that takes the unconstrained parameter draws and the integer index i of the left-out observation. It should return a DataArray containing the log-likelihood of the left-out observation i evaluated at each unconstrained parameter draw. The returned DataArray must have dimensions chain, draw. A minimal sketch of both callables is given after the parameter list, and a full worked example appears in the Examples section.

upars : xarray.DataArray, optional

Posterior draws transformed to the unconstrained parameter space. Must have chain and draw dimensions, plus one additional dimension containing all parameters. Parameter names can be provided as coordinate values on this dimension. If not provided, will attempt to use the unconstrained_posterior group from the input data if available.

var_name : str, optional

The name of the variable in the log_likelihood group storing the pointwise log likelihood data to use for the loo computation.

reff : float, optional

Relative MCMC efficiency, ess / n, i.e., the number of effective samples divided by the number of actual samples. Computed from the trace by default.

max_iters : int, default 30

Maximum number of moment matching iterations for each problematic observation.

k_threshold : float, optional

Threshold value for Pareto k values above which moment matching is applied. Defaults to \(\min(1 - 1/\log_{10}(S), 0.7)\), where S is the number of samples.

split : bool, default True

If True, only transform half of the draws and use multiple importance sampling to combine them with untransformed draws.

cov : bool, default False

If True, match the covariance structure during the transformation, in addition to the mean and marginal variances. Ignored if split=False.

pointwise : bool, optional

If True, the pointwise predictive accuracy will be returned. Defaults to rcParams["stats.ic_pointwise"]. Moment matching always requires pointwise data from loo_orig. This argument controls whether the returned object includes pointwise data.
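For orientation, the shape of the two callables can be sketched with a toy one-parameter model. Everything below (the parameter label "x", the dimension name upars_dim, the data y, and the priors) is purely illustrative; a complete real case is shown in the Examples section.

import numpy as np
import xarray as xr
from scipy import stats

y = np.array([0.2, -0.3, 1.1])  # hypothetical observations, y_j ~ Normal(x, 1)

def log_prob_upars(upars):
    # Log density of the full posterior at each unconstrained draw,
    # returned with dimensions chain and draw.
    x = upars.sel(upars_dim="x").values
    log_prob = stats.norm(0, 1).logpdf(x)                      # prior on x
    log_prob += stats.norm(x[..., None], 1).logpdf(y).sum(-1)  # likelihood
    return xr.DataArray(log_prob, dims=["chain", "draw"])

def log_lik_i_upars(upars, i):
    # Log likelihood of the single left-out observation i at each draw.
    x = upars.sel(upars_dim="x").values
    return xr.DataArray(stats.norm(x, 1).logpdf(y[i]), dims=["chain", "draw"])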

Returns:
ELPDData

Object with the following attributes:

  • elpd: expected log pointwise predictive density

  • se: standard error of the elpd

  • p: effective number of parameters

  • n_samples: number of samples

  • n_data_points: number of data points

  • warning: True if the estimated shape parameter of the Pareto distribution is greater than good_k.

  • elpd_i: DataArray with the pointwise predictive accuracy, only if pointwise=True

  • pareto_k: array of Pareto shape values, only if pointwise=True

  • good_k: for a sample size S, the threshold is computed as min(1 - 1/log10(S), 0.7)

  • approx_posterior: True if an approximate posterior was used.

See also

loo

Standard PSIS-LOO-CV.

reloo

Exact re-fitting for problematic observations.

Notes

The moment matching algorithm considers three affine transformations of the posterior draws. For a specific draw \(\theta^{(s)}\), a generic affine transformation includes a square matrix \(\mathbf{A}\) representing a linear map and a vector \(\mathbf{b}\) representing a translation such that

\[T : \theta^{(s)} \mapsto \mathbf{A}\theta^{(s)} + \mathbf{b} =: \theta^{*{(s)}}.\]

The first transformation, \(T_1\), is a translation that matches the mean of the sample to its importance weighted mean given by

\[\mathbf{\theta^{*{(s)}}} = T_1(\mathbf{\theta^{(s)}}) = \mathbf{\theta^{(s)}} - \bar{\theta} + \bar{\theta}_w,\]

where \(\bar{\theta}\) is the mean of the sample and \(\bar{\theta}_w\) is the importance weighted mean of the sample. The second transformation, \(T_2\), is a scaling that matches the marginal variances in addition to the means given by

\[\mathbf{\theta^{*{(s)}}} = T_2(\mathbf{\theta^{(s)}}) = \mathbf{v}^{1/2}_w \circ \mathbf{v}^{-1/2} \circ (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,\]

where \(\mathbf{v}\) and \(\mathbf{v}_w\) are the sample and weighted variances, and \(\circ\) denotes the pointwise product of the elements of two vectors. The third transformation, \(T_3\), is a covariance transformation that matches the covariance matrix of the sample to its importance weighted covariance matrix given by

\[\mathbf{\theta^{*{(s)}}} = T_3(\mathbf{\theta^{(s)}}) = \mathbf{L}_w \mathbf{L}^{-1} (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,\]

where \(\mathbf{L}\) and \(\mathbf{L}_w\) are the Cholesky decompositions of the covariance matrix and the weighted covariance matrix, respectively, i.e.,

\[\mathbf{LL}^T = \mathbf{\Sigma} = \frac{1}{S} \sum_{s=1}^S (\mathbf{\theta^{(s)}} - \bar{\theta}) (\mathbf{\theta^{(s)}} - \bar{\theta})^T\]

and

\[\mathbf{L}_w \mathbf{L}_w^T = \mathbf{\Sigma}_w = \frac{\frac{1}{S} \sum_{s=1}^S w^{(s)} (\mathbf{\theta^{(s)}} - \bar{\theta}_w) (\mathbf{\theta^{(s)}} - \bar{\theta}_w)^T}{\sum_{s=1}^S w^{(s)}}.\]

The algorithm iterates on \(T_1\) repeatedly and moves on to \(T_2\) and \(T_3\) only if \(T_1\) fails to yield a Pareto k statistic below the threshold.
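As an illustration only, the three transformations can be written in a few lines of NumPy. This is a sketch for clarity, not the implementation used internally; draws is assumed to be an (S, d) array of unconstrained draws and log_weights the smoothed log importance weights of the left-out observation, with the weighted covariance computed from normalized weights.

import numpy as np

def affine_transforms(draws, log_weights):
    # Normalized importance weights
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()

    mean = draws.mean(axis=0)   # sample mean
    mean_w = w @ draws          # importance weighted mean

    # T1: translation matching the weighted mean
    t1 = draws - mean + mean_w

    # T2: additionally match the marginal variances
    var = draws.var(axis=0)
    var_w = w @ (draws - mean_w) ** 2
    t2 = np.sqrt(var_w / var) * (draws - mean) + mean_w

    # T3: additionally match the covariance via Cholesky factors
    L = np.linalg.cholesky(np.cov(draws, rowvar=False))
    L_w = np.linalg.cholesky(((draws - mean_w).T * w) @ (draws - mean_w))
    t3 = np.linalg.solve(L, (draws - mean).T).T @ L_w.T + mean_w
    return t1, t2, t3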

References

[1]

Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. (2021). Implicitly Adaptive Importance Sampling. Statistics and Computing, 31(2). https://doi.org/10.1007/s11222-020-09982-2. arXiv preprint: https://arxiv.org/abs/1906.08850

[2]

Vehtari et al. (2017). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5). https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint: https://arxiv.org/abs/1507.04544

[3]

Vehtari et al. (2024). Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72). https://jmlr.org/papers/v25/19-556.html. arXiv preprint: https://arxiv.org/abs/1507.02646

Examples

Moment matching can improve PSIS-LOO-CV estimates for observations with high Pareto k values without having to refit the model for each problematic observation. We will use the non-centered eight schools data, which has one problematic observation:

In [1]: import arviz_base as az
   ...: from arviz_stats import loo, loo_moment_match
   ...: import numpy as np
   ...: import xarray as xr
   ...: from scipy import stats
   ...: 
   ...: idata = az.load_arviz_data("non_centered_eight")
   ...: loo_orig = loo(idata, pointwise=True, var_name="obs")
   ...: loo_orig
   ...: 
Out[1]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   -30.72     1.33
p_loo        0.90        -

There has been a warning during the calculation. Please check the results.
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        7   87.5%
   (0.70, 1]   (bad)         1   12.5%
    (1, Inf)   (very bad)    0    0.0%
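The flagged observation can be located from the pointwise Pareto k values stored in loo_orig. This is a quick check using only attributes documented in the Returns section above; it is not part of the original example:

import numpy as np

# Default threshold: min(1 - 1/log10(S), 0.7) for S posterior samples
k_threshold = min(1 - 1 / np.log10(loo_orig.n_samples), 0.7)
flagged = np.where(np.asarray(loo_orig.pareto_k) > k_threshold)[0]
print(flagged)  # indices of the observations that need moment matching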

For moment matching, we need the unconstrained parameters and two functions for the log probability and pointwise log-likelihood computations:

In [2]: posterior = idata.posterior
   ...: theta_t = posterior.theta_t.values
   ...: mu = posterior.mu.values[:, :, np.newaxis]
   ...: log_tau = np.log(posterior.tau.values)[:, :, np.newaxis]
   ...: 
   ...: upars = np.concatenate([theta_t, mu, log_tau], axis=2)
   ...: param_names = [f"theta_t_{i}" for i in range(8)] + ["mu", "log_tau"]
   ...: 
   ...: upars = xr.DataArray(
   ...:     upars,
   ...:     dims=["chain", "draw", "upars_dim"],
   ...:     coords={
   ...:         "chain": posterior.chain,
   ...:         "draw": posterior.draw,
   ...:         "upars_dim": param_names
   ...:     }
   ...: )
   ...: 
   ...: def log_prob_upars(upars):
   ...:     theta_tilde = upars.sel(upars_dim=[f"theta_t_{i}" for i in range(8)])
   ...:     mu = upars.sel(upars_dim="mu")
   ...:     log_tau = upars.sel(upars_dim="log_tau")
   ...:     tau = np.exp(log_tau)
   ...: 
   ...:     log_prob = stats.norm(0, 5).logpdf(mu.values)
   ...:     log_prob += stats.halfcauchy(0, 5).logpdf(tau.values)
   ...:     log_prob += log_tau.values
   ...:     log_prob += stats.norm(0, 1).logpdf(theta_tilde.values).sum(axis=-1)
   ...: 
   ...:     return xr.DataArray(
   ...:         log_prob,
   ...:         dims=["chain", "draw"],
   ...:         coords={"chain": upars.chain, "draw": upars.draw}
   ...:     )
   ...: 
   ...: def log_lik_i_upars(upars, i):
   ...:     sigmas = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
   ...:     theta_tilde_i = upars.sel(upars_dim=f"theta_t_{i}")
   ...:     mu = upars.sel(upars_dim="mu")
   ...:     tau = np.exp(upars.sel(upars_dim="log_tau"))
   ...:     theta_i = mu + tau * theta_tilde_i
   ...:     y_i = idata.observed_data.obs.values[i]
   ...:     log_lik = stats.norm(theta_i.values, sigmas[i]).logpdf(y_i)
   ...: 
   ...:     return xr.DataArray(
   ...:         log_lik,
   ...:         dims=["chain", "draw"],
   ...:         coords={"chain": upars.chain, "draw": upars.draw}
   ...:     )
   ...: 

We can now apply moment matching using the split transformation. We can see that all Pareto \(k\) values are now below the threshold and the ELPD is slightly improved:

In [3]: loo_mm = loo_moment_match(
   ...:     idata,
   ...:     loo_orig,
   ...:     upars=upars,
   ...:     log_prob_upars_fn=log_prob_upars,
   ...:     log_lik_i_upars_fn=log_lik_i_upars,
   ...:     var_name="obs",
   ...:     k_threshold=0.7,
   ...:     split=True,
   ...:     cov=False,
   ...: )
   ...: loo_mm
   ...: 
Out[3]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   -30.60     1.37
p_loo        0.79        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        8  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%
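As a final sanity check, the largest Pareto k before and after moment matching can be compared directly, assuming loo_mm also carries pointwise results (this depends on the pointwise argument or rcParams["stats.ic_pointwise"]):

import numpy as np

print("max Pareto k before:", float(np.asarray(loo_orig.pareto_k).max()))
print("max Pareto k after: ", float(np.asarray(loo_mm.pareto_k).max()))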