arviz_stats.loo_expectations
- arviz_stats.loo_expectations(data, var_name=None, log_weights=None, kind='mean', probs=None)
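Compute weighted expectations of the posterior predictive distribution using the PSIS-LOO-CV method.
The expectations are accurate only if the PSIS approximation is working well, which can be checked with the returned khat diagnostics. The PSIS-LOO-CV method is described in [1] and [2].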
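For intuition, the weighted expectation for a single observation i is a self-normalized importance-sampling estimate: with smoothed log weights log w_s over S posterior draws, E[f] ≈ Σ_s w_s f(ỹ_s) / Σ_s w_s. The minimal sketch below assumes plain NumPy arrays (draws along axis 0, observations along axis 1) instead of the DataTree or DataArray inputs this function actually takes, and the helper name `weighted_mean` is hypothetical, not part of arviz_stats.

```python
import numpy as np


def weighted_mean(pp_draws, log_weights):
    """Self-normalized importance-sampling mean for each observation.

    Hypothetical helper (not part of arviz_stats). Both arrays have shape
    (n_draws, n_obs): posterior predictive draws and their log weights.
    """
    # Subtract the per-observation maximum before exponentiating to avoid overflow.
    w = np.exp(log_weights - log_weights.max(axis=0, keepdims=True))
    # Normalize so the weights for each observation sum to 1.
    w /= w.sum(axis=0, keepdims=True)
    return (w * pp_draws).sum(axis=0)


# Toy example: 4 draws, 2 observations.
pp = np.array([[1.0, 10.0],
               [2.0, 20.0],
               [3.0, 30.0],
               [4.0, 40.0]])
# Equal log weights reduce to the ordinary sample mean: 2.5 and 25.0.
print(weighted_mean(pp, np.zeros_like(pp)))
```

Because the weights are normalized per observation, adding a constant to all log weights of an observation leaves the result unchanged.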
- Parameters:
- data: DataTree or InferenceData
It should contain the groups posterior_predictive and log_likelihood.
- var_name: str, optional
The name of the variable in the log_likelihood group storing the pointwise log-likelihood data to use for the LOO computation.
- log_weights: xarray.DataArray or ELPDData, optional
Smoothed log weights. Can be either:
A DataArray with the same shape as the log likelihood data.
An ELPDData object from a previous arviz_stats.loo call.
Defaults to None. If not provided, it will be computed using the PSIS-LOO method.
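When log_weights is not provided, the weights are computed with PSIS. The starting point is the raw leave-one-out importance ratio of each draw, r_s = 1 / p(y_i | θ_s), so log r_s = -log p(y_i | θ_s) comes straight from the log_likelihood group; PSIS then fits a generalized Pareto distribution to the largest ratios and replaces them with smoothed values [2]. The sketch below shows only the raw, unsmoothed step on a NumPy array with draws along axis 0; the Pareto smoothing is deliberately omitted, and `raw_loo_log_weights` is a hypothetical name.

```python
import numpy as np


def raw_loo_log_weights(log_lik):
    """Raw (unsmoothed) LOO log importance weights, normalized per observation.

    Hypothetical helper. log_lik has shape (n_draws, n_obs) and holds the
    pointwise log likelihood log p(y_i | theta_s). PSIS smoothing is NOT applied.
    """
    log_r = -log_lik  # log of r_s = 1 / p(y_i | theta_s)
    # Normalize on the log scale with a numerically stable log-sum-exp over draws.
    m = log_r.max(axis=0, keepdims=True)
    log_norm = m + np.log(np.exp(log_r - m).sum(axis=0, keepdims=True))
    return log_r - log_norm


log_lik = np.log(np.array([[0.5, 0.1],
                           [0.25, 0.1],
                           [0.25, 0.2]]))
lw = raw_loo_log_weights(log_lik)
# Draws that give y_i a lower likelihood receive more weight in the LOO estimate.
print(np.exp(lw))
```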
- kind: str, optional
The kind of expectation to compute. Available options are:
‘mean’: the mean of the posterior predictive distribution. Default.
‘median’: the median of the posterior predictive distribution.
‘var’: the variance of the posterior predictive distribution.
‘sd’: the standard deviation of the posterior predictive distribution.
‘quantile’: the quantile(s) of the posterior predictive distribution given by probs.
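The ‘var’ and ‘sd’ options replace the plain sample moments with importance-weighted ones. A minimal sketch, assuming normalized weights and the weighted variance Σ_s w_s (ỹ_s − μ)²; some implementations divide by 1 − Σ_s w_s² as a bias correction, and whether arviz_stats does so is an assumption not stated here, so the sketch exposes it as a flag. `weighted_var` is a hypothetical helper.

```python
import numpy as np


def weighted_var(draws, log_weights, unbiased=False):
    """Importance-weighted variance over draws (axis 0) for each observation.

    Hypothetical helper. With unbiased=True the result is divided by
    1 - sum(w**2), which reduces to the usual ddof=1 variance for equal weights.
    """
    w = np.exp(log_weights - log_weights.max(axis=0, keepdims=True))
    w /= w.sum(axis=0, keepdims=True)
    mu = (w * draws).sum(axis=0)
    var = (w * (draws - mu) ** 2).sum(axis=0)
    if unbiased:
        var /= 1.0 - (w ** 2).sum(axis=0)
    return var


draws = np.array([[1.0], [2.0], [3.0], [4.0]])
equal = np.zeros_like(draws)
print(weighted_var(draws, equal))                 # matches np.var(draws, axis=0)
print(np.sqrt(weighted_var(draws, equal, True)))  # 'sd' is the square root of 'var'
```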
- probs: float or list of float, optional
The quantile(s) to compute when kind is ‘quantile’.
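For ‘quantile’ (and ‘median’, which is the 0.5 quantile), each observation needs a quantile of a weighted sample. One simple definition, used here only as an illustration, sorts the draws, accumulates the normalized weights, and returns the first draw whose cumulative weight reaches the requested probability; arviz_stats may use a different rule (for example an interpolating one), so results need not match exactly. `weighted_quantile` is hypothetical and handles one observation at a time.

```python
import numpy as np


def weighted_quantile(draws, log_weights, probs):
    """Weighted quantile(s) of a 1-D sample using a step (inverse-CDF) rule.

    Hypothetical helper; draws and log_weights are 1-D arrays over posterior draws.
    """
    order = np.argsort(draws)
    x = draws[order]
    w = np.exp(log_weights[order] - log_weights.max())
    cdf = np.cumsum(w) / w.sum()
    # Index of the first sorted draw whose cumulative weight is >= p.
    idx = np.searchsorted(cdf, np.atleast_1d(probs), side="left")
    return x[np.minimum(idx, len(x) - 1)]


x = np.array([3.0, 1.0, 4.0, 2.0])
print(weighted_quantile(x, np.zeros(4), [0.25, 0.75]))          # equal weights
print(weighted_quantile(x, np.log([1.0, 1.0, 6.0, 1.0]), 0.5))  # most mass on 4.0
```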
- Returns:
- loo_expec: xarray.DataArray
The weighted expectations.
- khat: xarray.DataArray
Function-specific Pareto k-hat diagnostics for each observation.
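The returned khat values can be used to flag observations whose weighted expectation may be unreliable. A minimal sketch, assuming the sample-size-dependent threshold min(1 − 1/log10(S), 0.7) proposed in [2] for S posterior draws; the helper name `flag_high_khat` and the toy values are hypothetical, and in practice the input would be the values of the khat DataArray returned above.

```python
import numpy as np


def flag_high_khat(khat, n_draws):
    """Return indices of observations whose khat exceeds the PSIS threshold.

    Hypothetical helper. Uses the threshold min(1 - 1/log10(S), 0.7).
    """
    threshold = min(1.0 - 1.0 / np.log10(n_draws), 0.7)
    return np.flatnonzero(np.asarray(khat) > threshold), threshold


# Toy values, not taken from the radon example.
toy_khat = np.array([0.05, 0.42, 0.75, -0.08, 0.91])
bad, thr = flag_high_khat(toy_khat, n_draws=4000)
print(thr, bad)  # prints: 0.7 [2 4]
```

With fewer draws the threshold tightens: for S = 100 it drops to 0.5.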
References
[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint https://arxiv.org/abs/1507.04544.
[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html. arXiv preprint https://arxiv.org/abs/1507.02646.
Examples
Calculate the predictive 0.25 and 0.75 quantiles and the function-specific Pareto k-hat diagnostics:
In [1]: from arviz_stats import loo_expectations
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_expec, khat = loo_expectations(dt, kind="quantile", probs=[0.25, 0.75])
   ...: loo_expec
   ...:
Out[1]:
<xarray.DataArray 'y' (quantile: 2, obs_id: 919)> Size: 15kB
array([[-0.212921 ,  0.51393151,  0.50130407, ...,  1.00688331,  1.19475673,
         1.220047  ],
       [ 0.82123096,  1.49772976,  1.49586184, ...,  1.96597453,  2.22125176,
         2.23358734]], shape=(2, 919))
Coordinates:
  * obs_id    (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
  * quantile  (quantile) float64 16B 0.25 0.75
In [2]: khat
Out[2]:
<xarray.DataArray 'y' (obs_id: 919)> Size: 7kB
array([ 5.58019381e-02,  2.17065243e-01,  1.51669728e-02, ...,
       -4.69686709e-03,  3.13850332e-02,  3.92762729e-02])
Coordinates:
  * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918