Source code for arviz_stats.loo.loo_pit

"""Compute leave one out (PSIS-LOO) probability integral transform (PIT) values."""

import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree, extract
from xarray_einstats.stats import logsumexp

from arviz_stats.loo.helper_loo import _get_r_eff
from arviz_stats.utils import ELPDData, get_log_likelihood_dataset


def loo_pit(
    data,
    var_names=None,
    log_weights=None,
):
    r"""Compute leave-one-out (PSIS-LOO) probability integral transform (PIT) values.

    The LOO-PIT values are :math:`p(\tilde{y}_i \le y_i \mid y_{-i})`, where :math:`y_i`
    represents the observed data for index :math:`i` and :math:`\tilde{y}_i` represents the
    posterior predictive sample at index :math:`i`. Note that :math:`y_{-i}` indicates we have
    left out the :math:`i`-th observation. LOO-PIT values are computed using the PSIS-LOO-CV
    method described in [1]_ and [2]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        It should contain posterior, posterior_predictive and log_likelihood groups.
    var_names : str or list of str, optional
        Names of the variables to be used to compute the LOO-PIT values. If None, all
        variables are used. The function assumes that the observed and log_likelihood
        variables share the same names.
    log_weights : DataArray or ELPDData, optional
        Smoothed log weights. Can be either:

        - A DataArray with the same shape as ``y_pred``
        - An ELPDData object from a previous :func:`arviz_stats.loo` call.

        Defaults to None. If not provided, it will be computed using the PSIS-LOO method.

    Returns
    -------
    loo_pit : Dataset
        Value of the LOO-PIT at each observed data point.

    Examples
    --------
    Calculate LOO-PIT values using as test quantity the observed values themselves.

    .. ipython::

        In [1]: from arviz_stats import loo_pit
           ...: from arviz_base import load_arviz_data
           ...: dt = load_arviz_data("centered_eight")
           ...: loo_pit(dt)

    Calculate LOO-PIT values using as test quantity the square of the difference between
    each observation and `mu`. For this we create a new DataTree, copying the posterior and
    log_likelihood groups and creating new observed_data and posterior_predictive groups.

    .. ipython::

        In [1]: from arviz_base import from_dict
           ...: new_dt = from_dict({"posterior": dt.posterior,
           ...:                     "log_likelihood": dt.log_likelihood,
           ...:                     "observed_data": {
           ...:                         "obs": (dt.observed_data.obs
           ...:                                 - dt.posterior.mu.median(dim=("chain", "draw")))**2},
           ...:                     "posterior_predictive": {
           ...:                         "obs": (dt.posterior_predictive.obs - dt.posterior.mu)**2}})
           ...: loo_pit(new_dt)

    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing, 27(5) (2017).
        https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint
        https://arxiv.org/abs/1507.04544.
    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024).
        https://jmlr.org/papers/v25/19-556.html. arXiv preprint
        https://arxiv.org/abs/1507.02646.
    """
    data = convert_to_datatree(data)
    # Fixed seed so the randomized PIT used for ties (discrete data) is reproducible.
    rng = np.random.default_rng(214)

    if var_names is None:
        var_names = list(data.observed_data.data_vars.keys())
    elif isinstance(var_names, str):
        var_names = [var_names]

    log_likelihood = get_log_likelihood_dataset(data, var_names=var_names)

    if log_weights is None:
        # Smooth the raw importance ratios with PSIS; psislw returns
        # self-normalized log weights.
        n_samples = log_likelihood.chain.size * log_likelihood.draw.size
        reff = _get_r_eff(data, n_samples)
        log_weights, _ = log_likelihood.azstats.psislw(r_eff=reff)

    if isinstance(log_weights, ELPDData):
        if log_weights.log_weights is None:
            raise ValueError("ELPDData object does not contain log_weights")
        log_weights = log_weights.log_weights

    posterior_predictive = extract(
        data,
        group="posterior_predictive",
        combined=False,
        var_names=var_names,
        keep_dataset=True,
    )
    observed_data = extract(
        data,
        group="observed_data",
        combined=False,
        var_names=var_names,
        keep_dataset=True,
    )

    # Indicators for predictive draws strictly below the observation (sel_min)
    # and tied with it (sel_sup); ties only occur for discrete data.
    sel_min = {}
    sel_sup = {}
    for var in var_names:
        pred = posterior_predictive[var]
        obs = observed_data[var]
        sel_min[var] = pred < obs
        sel_sup[var] = pred == obs

    sel_min = xr.Dataset(sel_min)
    sel_sup = xr.Dataset(sel_sup)

    # Importance-weighted P(y_pred < y_obs): sum the normalized weights over
    # the draws below each observation.
    pit = np.exp(logsumexp(log_weights.where(sel_min, -np.inf), dims=["chain", "draw"]))

    loo_pit_values = xr.Dataset(coords=observed_data.coords)
    for var in var_names:
        pit_lower = pit[var].values

        if sel_sup[var].any():
            # Discrete case: spread the tied probability mass uniformly between
            # P(y_pred < y_obs) and P(y_pred <= y_obs).
            pit_sup_addition = np.exp(
                logsumexp(log_weights.where(sel_sup[var], -np.inf), dims=["chain", "draw"])
            )
            pit_upper = pit_lower + pit_sup_addition[var].values
            random_value = rng.uniform(pit_lower, pit_upper)
            loo_pit_values[var] = observed_data[var].copy(data=random_value)
        else:
            loo_pit_values[var] = observed_data[var].copy(data=pit_lower)

    return loo_pit_values
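For intuition, the computation above reduces, for a single observation, to evaluating an importance-weighted empirical CDF at the observed value, randomizing any tied mass for discrete data. A minimal sketch with plain NumPy, where ``y_obs``, ``y_pred`` and ``log_w`` are illustrative stand-ins rather than anything defined by the module:

    import numpy as np

    rng = np.random.default_rng(0)
    y_obs = 1.2                           # one observed data point
    y_pred = rng.normal(size=4000)        # posterior predictive draws for that point
    log_w = -rng.exponential(size=4000)   # stand-in for PSIS-smoothed log weights
    log_w -= np.logaddexp.reduce(log_w)   # self-normalize, as psislw output already is

    w = np.exp(log_w)
    pit_lower = np.sum(w[y_pred < y_obs])               # P(y_pred < y_obs)
    pit_upper = pit_lower + np.sum(w[y_pred == y_obs])  # add tied mass (discrete case)
    loo_pit_value = rng.uniform(pit_lower, pit_upper)   # equals pit_lower when there are no ties

Under a well-calibrated model the collection of LOO-PIT values is approximately uniform on [0, 1]; calibration checks compare the values against that reference. When smoothed log weights are already available from a previous :func:`arviz_stats.loo` call, the resulting ELPDData can be passed directly, assuming it carries log_weights (the function raises a ValueError otherwise):

    from arviz_stats import loo, loo_pit
    from arviz_base import load_arviz_data

    dt = load_arviz_data("centered_eight")
    elpd = loo(dt)                # reuse the PSIS log weights computed here
    loo_pit(dt, log_weights=elpd)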