arviz_stats.loo_expectations



arviz_stats.loo_expectations(data, var_name=None, log_weights=None, kind='mean', probs=None)

Compute weighted expectations using the PSIS-LOO-CV method.

The expectations are reliable only when the PSIS approximation is working well; the returned function-specific Pareto k-hat values can be used to check this. The PSIS-LOO-CV method is described in [1] and [2].
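As a rough illustration of the idea: for one observation y_i, a LOO expectation of a function h is a self-normalized importance-sampling average over the posterior predictive draws, where the raw LOO log importance ratios are the negative pointwise log likelihood, -log p(y_i | θ_s). The sketch below uses plain, unsmoothed ratios and made-up numbers. It omits the Pareto smoothing step that PSIS adds, so it is not the arviz_stats implementation, and the function name `loo_weighted_mean` is hypothetical.

```python
import numpy as np

def loo_weighted_mean(log_lik_i, values):
    """Self-normalized importance-sampling estimate of E[h(y_i)] under LOO.

    log_lik_i: pointwise log likelihood log p(y_i | theta_s) for S draws.
    values:    h evaluated at the posterior predictive draws for observation i.
    Raw LOO log ratios are -log_lik_i; no Pareto smoothing is applied here.
    """
    log_ratios = -np.asarray(log_lik_i, dtype=float)
    # Subtract the maximum before exponentiating for numerical stability.
    w = np.exp(log_ratios - log_ratios.max())
    w /= w.sum()  # normalize so the weights sum to 1
    return float(np.sum(w * np.asarray(values, dtype=float)))

# Synthetic example with three posterior draws for a single observation.
log_lik_i = np.log([0.5, 0.25, 0.25])
values = np.array([1.0, 2.0, 3.0])
result = loo_weighted_mean(log_lik_i, values)
```

Draws under which the held-out observation is less likely get larger weights, which is how the full-posterior draws are reweighted to approximate the leave-one-out posterior.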

Parameters:
data: DataTree or InferenceData

It should contain the groups posterior_predictive and log_likelihood.

var_name: str, optional

The name of the variable in the log_likelihood group storing the pointwise log likelihood data to use for the LOO computation.

log_weights: xarray.DataArray or ELPDData, optional

Smoothed log weights. Can be either:

  • A DataArray with the same shape as the log likelihood data.

  • An ELPDData object from a previous arviz_stats.loo call.

Defaults to None. If not provided, it will be computed using the PSIS-LOO method.
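Because array-valued log weights have to line up with the log likelihood, the weights for each observation live along the sample dimensions. The following is a minimal sketch of one way to normalize such weights per observation. It assumes a plain NumPy array laid out as (chain, draw, observation) rather than a labeled DataArray, and it uses raw, unsmoothed ratios; the helper name `normalize_log_weights` is hypothetical.

```python
import numpy as np

def normalize_log_weights(log_weights):
    """Normalize log weights over the sample axes (chain, draw) per observation.

    Assumes shape (n_chains, n_draws, n_obs), mirroring the log likelihood.
    Returns log weights whose exponentials sum to 1 for each observation.
    """
    lw = np.asarray(log_weights, dtype=float)
    # Log-sum-exp over the chain and draw axes, kept for broadcasting.
    m = lw.max(axis=(0, 1), keepdims=True)
    lse = m + np.log(np.exp(lw - m).sum(axis=(0, 1), keepdims=True))
    return lw - lse

rng = np.random.default_rng(0)
log_lik = rng.normal(size=(4, 100, 5))   # 4 chains, 100 draws, 5 observations
log_w = normalize_log_weights(-log_lik)  # raw (unsmoothed) LOO log ratios
print(np.exp(log_w).sum(axis=(0, 1)))    # each total should be close to 1.0
```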

kind: str, optional

The kind of expectation to compute. Available options are:

  • ‘mean’: the mean of the posterior predictive distribution. Default.

  • ‘median’: the median of the posterior predictive distribution.

  • ‘var’: the variance of the posterior predictive distribution.

  • ‘sd’: the standard deviation of the posterior predictive distribution.

  • ‘quantile’: the quantile(s) of the posterior predictive distribution at the levels given by probs.
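The ‘mean’, ‘var’ and ‘sd’ options correspond to weighted moments of the posterior predictive draws. A minimal NumPy sketch, assuming the draws are already flattened to shape (n_samples, n_obs) and the log weights are already normalized per observation. It computes the plain weighted second central moment; a library implementation may apply a small-sample correction to the variance, so this is illustrative only. The function name `weighted_moments` is hypothetical.

```python
import numpy as np

def weighted_moments(values, log_weights):
    """Weighted mean, variance and standard deviation per observation.

    values, log_weights: arrays of shape (n_samples, n_obs); exp(log_weights)
    is assumed to sum to 1 down each column. The variance is the plain
    weighted second central moment, with no small-sample correction.
    """
    w = np.exp(log_weights)
    mean = np.sum(w * values, axis=0)
    var = np.sum(w * (values - mean) ** 2, axis=0)
    return mean, var, np.sqrt(var)  # 'sd' is the square root of 'var'

values = np.array([[1.0, 0.0], [3.0, 4.0]])  # 2 samples, 2 observations
log_weights = np.log(np.array([[0.5, 0.25], [0.5, 0.75]]))
mean, var, sd = weighted_moments(values, log_weights)
```

The median is not a moment, so it needs a weighted quantile instead (the 0.5 level).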

probs: float or list of float, optional

The quantile(s) to compute when kind is ‘quantile’.
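For ‘median’ and ‘quantile’, the weighted draws have to be ranked rather than averaged. A minimal sketch of one common approach, inverting the weighted empirical CDF without interpolation between draws. arviz_stats may use a different interpolation rule, so the exact values are illustrative; the function name `weighted_quantile` is hypothetical. The median corresponds to probs=0.5.

```python
import numpy as np

def weighted_quantile(x, weights, probs):
    """Weighted quantiles of a 1-D sample by inverting the weighted ECDF.

    Returns, for each p in probs, the smallest sorted value whose cumulative
    normalized weight reaches p. No interpolation between draws is done.
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(weights, dtype=float)
    order = np.argsort(x)
    x_sorted = x[order]
    cum_w = np.cumsum(w[order]) / w.sum()
    idx = np.searchsorted(cum_w, np.atleast_1d(probs), side="left")
    # Guard against floating-point round-off pushing an index past the end.
    idx = np.minimum(idx, len(x_sorted) - 1)
    return x_sorted[idx]

x = np.array([3.0, 1.0, 2.0, 4.0])
w = np.array([0.1, 0.4, 0.3, 0.2])
q = weighted_quantile(x, w, [0.25, 0.5, 0.75])
```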

Returns:
loo_expec: xarray.DataArray

The weighted expectations.

khat: xarray.DataArray

Function-specific Pareto k-hat diagnostics for each observation.
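The returned k-hat values can be screened against the sample-size-dependent threshold min(1 - 1/log10(S), 0.7) described in [2], where S is the total number of posterior draws. A minimal sketch, with a plain NumPy array standing in for the khat DataArray; the helper names `khat_threshold` and `flag_unreliable` are hypothetical.

```python
import numpy as np

def khat_threshold(n_samples):
    """Sample-size-dependent Pareto k-hat threshold: min(1 - 1/log10(S), 0.7)."""
    return min(1.0 - 1.0 / np.log10(n_samples), 0.7)

def flag_unreliable(khat, n_samples):
    """Return indices of observations whose k-hat exceeds the threshold."""
    khat = np.asarray(khat, dtype=float)
    return np.flatnonzero(khat > khat_threshold(n_samples))

khat = np.array([0.05, 0.42, 0.71, -0.09, 0.95])
print(khat_threshold(4000))         # 0.7 for 4000 draws
print(flag_unreliable(khat, 4000))  # indices 2 and 4
```

With few draws the threshold tightens; for S = 100 it is 0.5.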

References

[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). DOI: https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint: https://arxiv.org/abs/1507.04544

[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). URL: https://jmlr.org/papers/v25/19-556.html arXiv preprint: https://arxiv.org/abs/1507.02646

Examples

Calculate the LOO predictive 0.25 and 0.75 quantiles and the function-specific Pareto k-hat diagnostics for the radon dataset:

In [1]: from arviz_stats import loo_expectations
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_expec, khat = loo_expectations(dt, kind="quantile", probs=[0.25, 0.75])
   ...: loo_expec
   ...: 
Out[1]: 
<xarray.DataArray 'y' (quantile: 2, obs_id: 919)> Size: 15kB
array([[-0.212921  ,  0.51393151,  0.50130407, ...,  1.00688331,
         1.19475673,  1.220047  ],
       [ 0.82123096,  1.49772976,  1.49586184, ...,  1.96597453,
         2.22125176,  2.23358734]], shape=(2, 919))
Coordinates:
  * obs_id    (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
  * quantile  (quantile) float64 16B 0.25 0.75
In [2]: khat
Out[2]: 
<xarray.DataArray 'y' (obs_id: 919)> Size: 7kB
array([ 5.58019381e-02,  2.17065243e-01,  1.51669728e-02,  4.23351189e-01,
        1.58790427e-01,  1.12689765e-01, -8.54932638e-02, -1.31580525e-02,
       -2.17511696e-02, -6.55677937e-02, -6.55677937e-02, -3.85099044e-02,
       -7.02219730e-02,  1.66029117e-02, -4.02703651e-02,  1.68766292e-01,
       -3.85099044e-02, -5.03002946e-02,  2.07949057e-01,  1.90443341e-01,
       -2.71199694e-02,  1.89252503e-01,  2.29352126e-01, -6.16692266e-02,
        2.22332121e-01,  1.58790427e-01,  2.25931523e-01,  2.37220381e-01,
        2.22162842e-01,  1.83360385e-01,  1.68766292e-01,  1.69012664e-01,
        1.23084739e-01,  1.98847482e-01, -3.40911332e-02,  1.57853655e-01,
        1.89252503e-01,  1.42561407e-01, -3.09466739e-02,  1.50971581e-01,
       -7.02219730e-02,  1.90443341e-01,  1.50971581e-01,  2.13051622e-01,
        1.29935004e-01,  1.00731289e-01, -3.72657428e-02, -7.67812513e-02,
       -1.62754981e-02, -4.77234266e-02,  1.50971581e-01, -7.67812513e-02,
       -4.77234266e-02, -7.02219730e-02, -6.16692266e-02, -2.71199694e-02,
        5.38740990e-02,  3.54929339e-01,  2.11252395e-01, -9.02117798e-02,
       -5.76982262e-02, -7.51876697e-02,  6.80144671e-02,  3.88992642e-01,
        5.58021714e-02,  5.11760831e-02,  1.98269690e-01,  1.57254942e-01,
        5.72824277e-02,  1.57718577e-01,  1.60494969e-01,  1.85930892e-01,
        2.34208620e-01,  1.51333890e-01, -1.12040465e-01, -1.33096328e-01,
       -3.52479114e-02, -1.11867987e-01, -9.22290180e-04, -1.65139922e-01,
...
        2.19333333e-01,  2.12890852e-01,  1.44202177e-01, -2.02705227e-02,
        9.79222241e-02,  2.14603355e-01,  5.02075472e-02,  1.28932069e-01,
        9.79222241e-02,  8.66157234e-02,  1.06065708e-02, -1.07490130e-01,
       -5.08264635e-02,  1.23775505e-01, -1.07490130e-01,  4.60357899e-02,
        1.12260448e-01,  3.98190109e-02,  4.69209174e-02, -9.98828035e-02,
        5.69970197e-02, -7.03502902e-02,  1.96562233e-01,  7.39301861e-02,
        1.64005653e-01,  1.06065708e-02,  9.27011855e-02, -7.53203830e-02,
        1.08314719e-01,  5.67120538e-02, -7.05157214e-02,  2.15861644e-01,
       -5.08264635e-02, -1.37513336e-02,  4.69209174e-02, -1.25223775e-01,
        2.33529888e-01,  8.39076753e-02,  1.82327468e-02, -7.53203830e-02,
        2.68337252e-02,  8.21658060e-02,  2.68337252e-02,  9.85446164e-02,
        9.27011855e-02,  2.00306137e-01,  1.18038597e-01,  1.97633878e-01,
        1.61925577e-01,  1.97250489e-01,  2.16266755e-01,  9.51183537e-02,
        7.98328460e-02, -8.96610880e-03,  7.65710309e-02,  1.18913209e-01,
        1.78532098e-01,  2.99531088e-01,  1.18913209e-01, -2.61446568e-02,
       -1.11961304e-02,  5.94893994e-02, -2.17032191e-03,  2.69609056e-01,
        7.19859624e-02, -7.85260250e-03,  1.29805286e-01,  1.79499690e-01,
        4.32856349e-02, -3.20726229e-02,  1.60560560e-01,  1.73804164e-01,
        4.22184757e-02,  1.44483291e-01,  9.74179760e-02,  1.22538673e-01,
       -4.69686709e-03,  3.13850332e-02,  3.92762729e-02])
Coordinates:
  * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918