arviz_stats.loo_metrics
- arviz_stats.loo_metrics(data, kind='rmse', var_name=None, log_weights=None, round_to='2g')
Compute predictive metrics using the PSIS-LOO-CV method.
Currently supported metrics are mean absolute error, mean squared error and root mean squared error. For classification problems, accuracy and balanced accuracy are also supported.
The PSIS-LOO-CV method is described in [1] and [2].
- Parameters:
- data: DataTree or InferenceData
Input data. It must contain the groups observed_data, posterior_predictive and log_likelihood.
- kind: str
The kind of metric to compute. Available options are:
‘mae’: mean absolute error.
‘mse’: mean squared error.
‘rmse’: root mean squared error. Default.
‘acc’: classification accuracy.
‘acc_balanced’: balanced classification accuracy.
- var_name: str, optional
The name of the variable in the log_likelihood group that stores the pointwise log likelihood data to use for the LOO computation.
- log_weights: DataArray or ELPDData, optional
Smoothed log weights. Can be either:
A DataArray with the same shape as the log likelihood data.
An ELPDData object from a previous arviz_stats.loo call.
Defaults to None. If not provided, they will be computed using the PSIS-LOO method.
- round_to: int or str, optional
If an integer, the number of decimal places to round the result to. If a string of the form ‘2g’, the number of significant digits to round the result to. Defaults to ‘2g’.
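The two forms of round_to can be illustrated with plain Python formatting. This is a minimal sketch, not the library's implementation: round_sketch is a hypothetical helper, and it assumes that a string such as ‘2g’ behaves like Python's general format specifier.

```python
def round_sketch(value, round_to="2g"):
    """Hypothetical helper illustrating the two accepted forms of round_to."""
    if isinstance(round_to, int):
        # Integer: number of decimal places.
        return round(value, round_to)
    # String such as "2g": number of significant digits (assumed to map
    # onto Python's general format specifier).
    return float(format(value, f".{round_to}"))


print(round_sketch(0.02345, "2g"))  # 0.023 (two significant digits)
print(round_sketch(0.7432, "2g"))   # 0.74
print(round_sketch(0.02345, 2))     # 0.02 (two decimal places)
```

The difference matters mostly for small values such as standard errors, where a fixed number of decimal places can discard most of the information.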
- Returns:
- estimate: collections.namedtuple
A namedtuple with the mean of the computed metric and its standard error.
References
[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544
[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646
Examples
Calculate predictive root mean squared error
In [1]: from arviz_stats import loo_metrics
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_metrics(dt, kind="rmse")
   ...:
Out[1]: rmse(mean=0.74, se=0.023)
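Sketch of the idea behind a leave-one-out RMSE

The following NumPy sketch illustrates how a leave-one-out prediction can be formed from posterior predictive draws and pointwise log likelihood values. It is not the arviz_stats implementation. It uses raw importance weights and omits the Pareto smoothing step, the data are simulated, loo_rmse_sketch is a hypothetical name, and the delta-method standard error is an assumption of this sketch.

```python
from collections import namedtuple

import numpy as np

# Hypothetical result type, mirroring the (mean, se) pair returned by loo_metrics.
Rmse = namedtuple("rmse", ["mean", "se"])


def loo_rmse_sketch(y_obs, y_pred, log_lik):
    """y_obs has shape (n,); y_pred and log_lik have shape (n_samples, n)."""
    # Raw leave-one-out importance ratios are 1 / p(y_i | theta_s),
    # so their logarithm is the negated pointwise log likelihood.
    log_w = -log_lik
    # Normalize the weights per observation in a numerically stable way.
    log_w = log_w - log_w.max(axis=0)
    w = np.exp(log_w)
    w /= w.sum(axis=0)
    # Weighted posterior predictive mean for each held-out observation.
    y_loo = (w * y_pred).sum(axis=0)
    sq_err = (y_obs - y_loo) ** 2
    rmse = np.sqrt(sq_err.mean())
    # Delta-method standard error of sqrt(mse); an assumption of this sketch.
    se_mse = sq_err.std(ddof=1) / np.sqrt(sq_err.size)
    return Rmse(mean=rmse, se=se_mse / (2 * rmse))


# Simulated data: a normal model with unit variance and an uncertain mean.
rng = np.random.default_rng(0)
n_samples, n = 500, 20
y_obs = rng.normal(size=n)
mu = rng.normal(scale=0.1, size=(n_samples, 1))   # posterior draws of the mean
y_pred = mu + rng.normal(size=(n_samples, n))     # posterior predictive draws
log_lik = -0.5 * (y_obs - mu) ** 2 - 0.5 * np.log(2 * np.pi)

result = loo_rmse_sketch(y_obs, y_pred, log_lik)
print(result.mean, result.se)
```

Raw importance weights can have very high variance, which is the problem that the Pareto smoothing described in [2] addresses.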
Calculate accuracy of a logistic regression model
In [2]: dt = load_arviz_data("anes")
   ...: loo_metrics(dt, kind="acc")
   ...:
Out[2]: acc(mean=0.82, se=0.02)
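Sketch of accuracy versus balanced accuracy

The difference between ‘acc’ and ‘acc_balanced’ is easiest to see on imbalanced data. The sketch below uses the textbook definitions on made-up values. The 0.5 threshold and the helper names are assumptions of this example, and it does not show how arviz_stats forms the leave-one-out predictions.

```python
import numpy as np


def accuracy(y_true, y_hat):
    """Fraction of observations whose predicted class matches the observed one."""
    return float(np.mean(y_true == y_hat))


def balanced_accuracy(y_true, y_hat):
    """Average of the per-class recalls (binary case: mean of TPR and TNR)."""
    recalls = [np.mean(y_hat[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


# Made-up imbalanced outcome: eight positives, two negatives.
y_true = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0])
# Hypothetical predicted probabilities, thresholded at 0.5 (an assumption).
p_hat = np.array([0.9, 0.8, 0.7, 0.6, 0.9, 0.8, 0.7, 0.6, 0.7, 0.2])
y_hat = (p_hat > 0.5).astype(int)

print(accuracy(y_true, y_hat))           # 0.9
print(balanced_accuracy(y_true, y_hat))  # 0.75
```

Here plain accuracy looks high because the majority class is always predicted correctly, while balanced accuracy exposes that only half of the minority class is recovered.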