arviz_stats.metrics

arviz_stats.metrics(data, kind='rmse', var_name=None, sample_dims=None, round_to='2g')

Compute performance metrics.

Currently supported metrics are mean absolute error, mean squared error and root mean squared error. For classification problems, accuracy and balanced accuracy are also supported.

Parameters:
data: DataTree or InferenceData

It should contain groups observed_data and posterior_predictive.

kind: str

The kind of metric to compute. Available options are:

  • ‘mae’: mean absolute error.

  • ‘mse’: mean squared error.

  • ‘rmse’: root mean squared error. Default.

  • ‘acc’: classification accuracy.

  • ‘acc_balanced’: balanced classification accuracy.
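For orientation, the options above correspond to standard textbook definitions, which can be sketched in plain NumPy. This is an illustration only, not the arviz_stats implementation: the arrays y_true, y_pred, labels and pred_labels are made-up data, and balanced accuracy is assumed here to be the mean of the per-class recalls.

```python
import numpy as np

# Made-up continuous data, for illustration only.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.0, 2.0, 5.0])

errors = y_pred - y_true
mae = np.mean(np.abs(errors))   # mean absolute error
mse = np.mean(errors**2)        # mean squared error
rmse = np.sqrt(mse)             # root mean squared error

# Made-up binary labels, for illustration only.
labels = np.array([0, 0, 0, 1])
pred_labels = np.array([0, 0, 1, 1])

# Accuracy: fraction of correctly classified observations.
acc = np.mean(labels == pred_labels)

# Balanced accuracy (assumed definition): average of the recall of each class.
recalls = [np.mean(pred_labels[labels == c] == c) for c in np.unique(labels)]
acc_balanced = np.mean(recalls)

print(mae, mse, rmse, acc, acc_balanced)
```

Balanced accuracy differs from plain accuracy when classes are imbalanced, as in the toy labels above where class 0 has three observations and class 1 has one.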

var_name: str, optional

The name of the observed and predicted variable.

sample_dims: iterable of hashable, optional

Dimensions to treat as sample dimensions; the computation reduces over them. Defaults to rcParams["data.sample_dims"].

round_to: int or str, optional

If an integer, the number of decimal places to round the result to. If a string of the form ‘2g’, the number of significant digits to round the result to. Defaults to ‘2g’. Use None to return raw numbers.
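The three documented forms of round_to can be mimicked with a small helper. The function round_value below is hypothetical and is not part of arviz_stats; it only sketches the described behaviour using Python's built-in rounding and the 'g' format specifier.

```python
def round_value(value, round_to):
    """Hypothetical helper mimicking the documented round_to options."""
    if round_to is None:
        # Return the raw number untouched.
        return value
    if isinstance(round_to, int):
        # Integer: number of decimal places.
        return round(value, round_to)
    if isinstance(round_to, str) and round_to.endswith("g"):
        # String such as '2g': number of significant digits.
        digits = int(round_to[:-1])
        return float(f"{value:.{digits}g}")
    raise ValueError(f"Unsupported round_to value: {round_to!r}")


raw = 0.021934  # made-up raw value
print(round_value(raw, "2g"))  # two significant digits
print(round_value(raw, 2))     # two decimal places
print(round_value(raw, None))  # unrounded
```

Significant-digit rounding keeps small values such as standard errors informative, whereas a fixed number of decimal places can collapse them toward zero.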

Returns:
estimate: collections.namedtuple

A namedtuple with the mean of the computed metric and its standard error.
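Because the result is a namedtuple, its fields can be read by name or unpacked positionally. The sketch below builds a stand-in object with collections.namedtuple rather than calling arviz_stats; the type name and the field names mean and se follow the output shown in the Examples section, and the values are copied from there.

```python
from collections import namedtuple

# Stand-in for a returned estimate (not produced by arviz_stats here).
rmse = namedtuple("rmse", ["mean", "se"])
estimate = rmse(mean=0.72, se=0.022)

# Access by field name.
print(estimate.mean, estimate.se)

# Positional unpacking also works, since a namedtuple is a tuple.
mean, se = estimate

# Convert to a plain dict, e.g. for logging or building a table.
as_dict = estimate._asdict()
print(as_dict)
```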

Notes

The metrics are computed by first reducing the posterior predictive samples. This mirrors how the arviz_stats.loo_metrics function computes them, which makes the two sets of results easier to compare.
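The order of operations described above can be sketched with simulated data. This is a minimal illustration under stated assumptions, not the library's code: the reduction over the sample dimensions is assumed to be a mean, the standard error is assumed to be the standard deviation of the per-observation errors divided by the square root of the number of observations, and all array names are invented.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chain, n_draw, n_obs = 4, 100, 50

# Simulated observations and posterior predictive samples
# with shape (chain, draw, observation).
y_obs = rng.normal(size=n_obs)
y_pp = y_obs + rng.normal(scale=0.5, size=(n_chain, n_draw, n_obs))

# Step 1: reduce over the sample dimensions (assumed: mean over chain and draw),
# giving one point prediction per observation.
y_hat = y_pp.mean(axis=(0, 1))

# Step 2: compute the per-observation error from the reduced predictions.
abs_err = np.abs(y_obs - y_hat)

# Step 3: summarise across observations (assumed form of the standard error).
mae_mean = abs_err.mean()
mae_se = abs_err.std(ddof=1) / np.sqrt(n_obs)

# For contrast: computing the error per draw and averaging afterwards
# gives a value that is never smaller than the reduce-first result.
mae_per_draw = np.abs(y_pp - y_obs).mean()

print(mae_mean, mae_se, mae_per_draw)
```

The contrast at the end shows why the order matters: averaging predictions first cancels part of the sampling noise, so the two approaches are not interchangeable when comparing results.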

Examples

Calculate root mean squared error

In [1]: from arviz_stats import metrics
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: metrics(dt, kind="rmse")
   ...: 
Out[1]: rmse(mean=0.72, se=0.022)

Calculate accuracy of a logistic regression model

In [2]: dt = load_arviz_data("anes")
   ...: metrics(dt, kind="acc")
   ...: 
Out[2]: acc(mean=0.82, se=0.02)