Source code for arviz_stats.loo.loo_approximate_posterior

"""Compute PSIS-LOO-CV for approximate posteriors."""

from arviz_base import rcParams

from arviz_stats.loo.helper_loo import (  # pylint: disable=cyclic-import
    _check_log_density,
    _compute_loo_results,
    _prepare_loo_inputs,
)


def loo_approximate_posterior(data, log_p, log_q, pointwise=None, var_name=None, log_jacobian=None):
    r"""Compute PSIS-LOO-CV for approximate posteriors.

    Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
    importance sampling leave-one-out cross-validation (PSIS-LOO-CV) for approximate
    posteriors (e.g., from variational inference). Requires log-densities of the target
    (log_p) and proposal (log_q) distributions.

    The PSIS-LOO-CV method is described in [1]_ and [2]_. The approximate posterior
    correction is computed using the method described in [3]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the log_likelihood group corresponding to samples
        drawn from the proposal distribution (q).
    log_p : ndarray or DataArray
        The (target) log-density evaluated at the same S draws as the log likelihood,
        i.e. the samples drawn from the proposal distribution (q).
        If ndarray, should be a vector of length S where S is the number of samples.
        If DataArray, should have dimensions matching the sample dimensions
        ("chain", "draw").
    log_q : ndarray or DataArray
        The (proposal) log-density evaluated at S samples from the proposal
        distribution (q).
        If ndarray, should be a vector of length S where S is the number of samples.
        If DataArray, should have dimensions matching the sample dimensions
        ("chain", "draw").
    pointwise : bool, optional
        If True, returns pointwise values. Defaults to rcParams["stats.ic_pointwise"].
    var_name : str, optional
        The name of the variable in log_likelihood groups storing the pointwise log
        likelihood data to use for loo computation.
    log_jacobian : DataArray, optional
        Log-Jacobian adjustment for variable transformations. Required when the model was
        fitted on transformed response data :math:`z = T(y)` but you want to compute ELPD
        on the original response scale :math:`y`. The value should be
        :math:`\log|\frac{dz}{dy}|` (the log absolute value of the derivative of the
        transformation). Must be a DataArray with dimensions matching the observation
        dimensions.

    Returns
    -------
    ELPDData
        Object with the following attributes:

        - **elpd**: expected log pointwise predictive density
        - **se**: standard error of the elpd
        - **p**: effective number of parameters
        - **n_samples**: number of samples
        - **n_data_points**: number of data points
        - **warning**: True if the estimated shape parameter of Pareto distribution is
          greater than ``good_k``.
        - **elpd_i**: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
          only if ``pointwise=True``
        - **pareto_k**: array of Pareto shape values, only if ``pointwise=True``
        - **good_k**: For a sample size S, the threshold is computed as
          ``min(1 - 1/log10(S), 0.7)``
        - **approx_posterior**: True if approximate posterior was used.

    Notes
    -----
    The raw log importance ratios passed to PSIS are
    ``-log_likelihood + (log_p - log_q)``. They are shifted by their maximum before
    smoothing to limit underflow/overflow. The relative efficiency is not estimated;
    ``r_eff=1.0`` is always used.

    Examples
    --------
    Calculate LOO for posterior approximations. The following example is intentionally
    minimal to demonstrate basic usage. The approximate posterior created below may not
    accurately represent the data and may lead to less meaningful LOO results.

    Create dummy log-densities:

    .. ipython::

        In [1]: import numpy as np
           ...: import xarray as xr
           ...: from arviz_stats import loo_approximate_posterior
           ...: from arviz_base import load_arviz_data, extract
           ...:
           ...: data = load_arviz_data("centered_eight")
           ...: log_lik = extract(data, group="log_likelihood", var_names="obs", combined=False)
           ...: rng = np.random.default_rng(214)
           ...:
           ...: values_p = rng.normal(loc=0, scale=1, size=(log_lik.chain.size, log_lik.draw.size))
           ...: log_p = xr.DataArray(
           ...:     values_p,
           ...:     dims=["chain", "draw"],
           ...:     coords={"chain": log_lik.chain, "draw": log_lik.draw}
           ...: )
           ...:
           ...: values_q = rng.normal(loc=-1, scale=1, size=(log_lik.chain.size, log_lik.draw.size))
           ...: log_q = xr.DataArray(
           ...:     values_q,
           ...:     dims=["chain", "draw"],
           ...:     coords={"chain": log_lik.chain, "draw": log_lik.draw}
           ...: )

    Calculate approximate pointwise LOO:

    .. ipython::

        In [2]: loo_approx = loo_approximate_posterior(
           ...:     data,
           ...:     log_p=log_p,
           ...:     log_q=log_q,
           ...:     var_name="obs",
           ...:     pointwise=True
           ...: )
           ...: loo_approx

    See Also
    --------
    loo : Standard PSIS-LOO-CV.
    loo_subsample : Sub-sampled PSIS-LOO-CV.
    compare : Compare models based on their ELPD.

    References
    ----------

    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing. 27(5) (2017)
        https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024)
        https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646

    .. [3] Magnusson, M., Riis Andersen, M., Jonasson, J., & Vehtari, A. *Bayesian
        Leave-One-Out Cross-Validation for Large Data.* Proceedings of the 36th
        International Conference on Machine Learning, PMLR 97:4244–4253 (2019)
        https://proceedings.mlr.press/v97/magnusson19a.html
        arXiv preprint https://arxiv.org/abs/1904.10679
    """
    loo_inputs = _prepare_loo_inputs(data, var_name)
    if pointwise is None:
        pointwise = rcParams["stats.ic_pointwise"]

    log_likelihood = loo_inputs.log_likelihood
    sample_dims = loo_inputs.sample_dims
    n_samples = loo_inputs.n_samples

    log_p = _check_log_density(log_p, "log_p", log_likelihood, n_samples, sample_dims)
    log_q = _check_log_density(log_q, "log_q", log_likelihood, n_samples, sample_dims)

    # Correction for having drawn from q instead of p, shifted by its global maximum
    # to handle underflow/overflow.
    approx_correction = log_p - log_q
    approx_correction = approx_correction - approx_correction.max()

    # Negation already returns a new array, so no explicit copy of the
    # log likelihood is needed before adding the correction.
    corrected_log_ratios = -log_likelihood + approx_correction

    # Second stabilization: shift by the maximum taken over the sample dimensions.
    corrected_log_ratios = corrected_log_ratios - corrected_log_ratios.max(dim=sample_dims)

    # r_eff is deliberately ignored here and fixed to 1.0.
    log_weights, pareto_k = corrected_log_ratios.azstats.psislw(r_eff=1.0, dim=sample_dims)

    return _compute_loo_results(
        log_likelihood=log_likelihood,
        var_name=loo_inputs.var_name,
        pointwise=pointwise,
        sample_dims=sample_dims,
        n_samples=n_samples,
        n_data_points=loo_inputs.n_data_points,
        log_weights=log_weights,
        pareto_k=pareto_k,
        approx_posterior=True,
        log_jacobian=log_jacobian,
    )