arviz_stats.loo_kfold

arviz_stats.loo_kfold(data, wrapper, pointwise=None, var_name=None, k=10, folds=None, stratify_by=None, group_by=None, save_fits=False)

Perform exact K-fold cross-validation.

K-fold cross-validation evaluates model predictive accuracy by partitioning the data into K complementary subsets (folds), then iteratively refitting the model K times, each time holding out one fold as a test set and training on the remaining K-1 folds.

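Each observation is used for testing exactly once, so every data point contributes one out-of-sample predictive evaluation. Unlike PSIS-LOO-CV (Pareto-smoothed importance sampling leave-one-out cross-validation), which approximates leave-one-out cross-validation from a single fit, K-fold requires actually refitting the model K times, but its results do not rely on an importance-sampling approximation. Because each refit sees only (K-1)/K of the data, the estimate can be slightly pessimistic, especially for small K.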
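The partition-and-refit loop described above can be sketched in plain NumPy. The helper kfold_splits below is hypothetical (it is not part of arviz_stats) and assumes random, nearly equal folds; it only produces the index sets, and a real run would refit the model on each train_idx and score the matching test_idx.

```python
import numpy as np


def kfold_splits(n_obs, k, seed=0):
    """Yield (train_idx, test_idx) pairs for K-fold cross-validation.

    Hypothetical helper for illustration; not part of arviz_stats.
    """
    rng = np.random.default_rng(seed)
    # Shuffle the observation indices, then cut them into k nearly equal folds
    permuted = rng.permutation(n_obs)
    folds = np.array_split(permuted, k)
    for test_idx in folds:
        # Train on every observation that is not in the held-out fold
        train_idx = np.setdiff1d(np.arange(n_obs), test_idx)
        yield train_idx, np.sort(test_idx)


# Eight observations and four folds: four refits, two observations held out each time
splits = list(kfold_splits(n_obs=8, k=4))
for train_idx, test_idx in splits:
    print("train:", train_idx, "test:", test_idx)
```

Concatenating the test sets recovers every index exactly once, which is the property that makes each observation contribute a single out-of-sample evaluation.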

Parameters:
data : xarray.DataTree or InferenceData

Input data containing the posterior and log_likelihood groups from the full model fit.

wrapper : SamplingWrapper

An instance of a SamplingWrapper subclass that handles model refitting. The wrapper must implement the following methods: sel_observations, sample, get_inference_data, and log_likelihood__i.

pointwise : bool, optional

If True, return pointwise estimates. Defaults to rcParams["stats.ic_pointwise"].

var_name : str, optional

The name of the variable in the log_likelihood group that stores the pointwise log likelihood data to use for computation.

k : int, default=10

The number of folds for cross-validation. The data will be partitioned into k subsets of equal (or approximately equal) size.

folds : array or xarray.DataArray, optional

Manual fold assignments (1 to k) for each observation. For example, [1,1,2,2,3,3,4,4] assigns the first two observations to fold 1, the next two to fold 2, and so on. If not provided, k random folds of (approximately) equal size are created. Cannot be used together with stratify_by or group_by.
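A minimal sketch of how manual fold labels translate into held-out index sets, assuming labels run from 1 to k as described above. The test_indices dictionary is illustrative only; it is not necessarily how the library organizes its fold_indices attribute.

```python
import numpy as np

# Manual fold labels (1 to k), as in the example above
folds = np.array([1, 1, 2, 2, 3, 3, 4, 4])
k = int(folds.max())

# Map each fold label to the positions of the observations it holds out
test_indices = {fold: np.flatnonzero(folds == fold) for fold in range(1, k + 1)}
print(test_indices)
```

Labelling every observation with its own fold, for example folds = np.arange(1, n + 1), makes each test set a single observation, which is the leave-one-out case mentioned in the Notes.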

stratify_by : array or xarray.DataArray, optional

Stratification variable used to maintain class proportions across folds. For example, [0,0,1,1,0,0,1,1] aims to give each fold the same 50/50 split of class 0 and class 1. Cannot be used together with folds or group_by.
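One simple way to stratify is to shuffle each class separately and deal its members out to the folds round-robin. The stratified_folds helper below is a hypothetical sketch of that idea, not necessarily the scheme arviz_stats uses; with the example labels and k=2 it places two observations of each class in each fold.

```python
import numpy as np


def stratified_folds(strata, k, seed=0):
    """Assign fold labels 1..k so each stratum is spread evenly across folds.

    Hypothetical helper for illustration; arviz_stats may use a different scheme.
    """
    strata = np.asarray(strata)
    rng = np.random.default_rng(seed)
    folds = np.empty(strata.size, dtype=int)
    offset = 0
    for label in np.unique(strata):
        # Shuffle the members of this stratum, then deal them out round-robin
        members = rng.permutation(np.flatnonzero(strata == label))
        folds[members] = (np.arange(members.size) + offset) % k + 1
        # Continue the round-robin where the previous stratum stopped,
        # so leftover observations do not all pile into fold 1
        offset += members.size
    return folds


strata = np.array([0, 0, 1, 1, 0, 0, 1, 1])
folds = stratified_folds(strata, k=2)
for fold in (1, 2):
    # Count of class 0 and class 1 inside this fold
    print(fold, np.bincount(strata[folds == fold], minlength=2))
```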

group_by : array or xarray.DataArray, optional

Grouping variable that keeps related observations together in the same fold. For example, [1,1,2,2,3,3] places all observations from group 1 in a single fold, all observations from group 2 in a single fold, and so on. Useful for repeated measures or clustered data. Cannot be used together with folds or stratify_by.
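Grouped splitting can be sketched by assigning folds to whole groups and letting each observation inherit its group's fold. The grouped_folds helper is hypothetical and assumes the groups are shuffled at random before being spread over the k folds.

```python
import numpy as np


def grouped_folds(groups, k, seed=0):
    """Assign fold labels 1..k so that each group lands in exactly one fold.

    Hypothetical helper for illustration; arviz_stats may use a different scheme.
    """
    groups = np.asarray(groups)
    rng = np.random.default_rng(seed)
    # Fold assignment happens at the group level, not the observation level
    unique_groups, inverse = np.unique(groups, return_inverse=True)
    group_fold = rng.permutation(unique_groups.size) % k + 1
    # Every observation inherits the fold of its group
    return group_fold[inverse]


groups = np.array([1, 1, 2, 2, 3, 3, 4, 4])
folds = grouped_folds(groups, k=2)
print(folds)
```

Because folds are drawn per group, no group is ever split between the training and test sets of the same refit.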

save_fits : bool, default=False

If True, store the fitted models and fold indices in the returned object.

Returns:
ELPDData

Object with the following attributes:

  • elpd: expected log pointwise predictive density

  • se: standard error of the elpd

  • p: effective number of parameters

  • n_samples: number of samples per fold

  • n_data_points: number of data points

  • warning: True if any issues occurred during fitting

  • elpd_i: pointwise predictive accuracy (if pointwise=True)

  • p_kfold_i: pointwise effective number of parameters (if pointwise=True)

  • pareto_k: None (not applicable for k-fold)

  • scale: “log”

Additional attributes when save_fits=True:

  • fold_fits: Dictionary containing fitted models for each fold

  • fold_indices: Dictionary containing test indices for each fold
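The summary numbers can be sketched from held-out log likelihood draws. The code below assumes the usual PSIS-LOO definitions carried over to K-fold: elpd_i is the log of the Monte Carlo average predictive density of observation i under the fit that excluded its fold, se is sqrt(n * Var(elpd_i)), and p_kfold_i is the in-sample lpd minus elpd_i, analogous to p_loo. The function kfold_summary and the synthetic arrays are illustrative assumptions, not the library's code or real model output.

```python
import numpy as np
from scipy.special import logsumexp


def kfold_summary(log_lik_heldout, log_lik_full):
    """Summarize K-fold results from (n_draws, n_obs) log-likelihood arrays.

    log_lik_heldout[s, i]: log p(y_i | draw s of the fit that excluded y_i's fold)
    log_lik_full[s, i]:    log p(y_i | draw s of the fit to all the data)
    Assumed formulas for illustration, mirroring the usual PSIS-LOO summaries.
    """
    n_draws, n_obs = log_lik_heldout.shape
    # Log of the Monte Carlo average of the predictive density, per observation
    elpd_i = logsumexp(log_lik_heldout, axis=0) - np.log(n_draws)
    lpd_i = logsumexp(log_lik_full, axis=0) - np.log(n_draws)
    elpd = elpd_i.sum()
    se = np.sqrt(n_obs * np.var(elpd_i))
    # Effective number of parameters: in-sample fit minus out-of-sample fit
    p_kfold_i = lpd_i - elpd_i
    return elpd, se, p_kfold_i.sum(), elpd_i, p_kfold_i


# Synthetic draws: held-out fits are made to predict slightly worse than the full fit
rng = np.random.default_rng(0)
full = rng.normal(-3.5, 0.3, size=(1000, 8))
heldout = full - rng.uniform(0.1, 0.5, size=(1000, 8))
elpd, se, p, elpd_i, p_i = kfold_summary(heldout, full)
print(f"elpd_kfold={elpd:.2f}  se={se:.2f}  p_kfold={p:.2f}")
```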

See also

loo

Pareto-smoothed importance sampling LOO-CV

SamplingWrapper

Base class for implementing sampling wrappers

Notes

When K equals the number of observations, this becomes exact leave-one-out cross-validation. Note that arviz_stats.loo provides a much more efficient approximation for that case and is recommended.

References

[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544

[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646

Examples

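Unlike PSIS-LOO, which approximates LOO-CV from a single fit, k-fold cross-validation refits the model k times, so we need to tell loo_kfold how to refit the model.

This is done by subclassing SamplingWrapper and implementing four methods: sel_observations splits the data into training and held-out parts, sample fits the model to the training data, get_inference_data converts that fit into InferenceData, and log_likelihood__i evaluates the log likelihood of the held-out observations.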

In [1]: import numpy as np
   ...: import xarray as xr
   ...: from scipy import stats
   ...: from arviz_base import load_arviz_data, from_dict
   ...: from arviz_stats import loo_kfold
   ...: from arviz_stats.loo import SamplingWrapper
   ...: 
   ...: class CenteredEightWrapper(SamplingWrapper):
   ...:     def __init__(self, idata):
   ...:         super().__init__(model=None, idata_orig=idata)
   ...:         self.y_obs = idata.observed_data["obs"].values
   ...:         self.sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
   ...: 
   ...:     def sel_observations(self, idx):
   ...:         all_idx = np.arange(len(self.y_obs))
   ...:         train_idx = np.setdiff1d(all_idx, idx)
   ...: 
   ...:         train_data = {
   ...:             "y": self.y_obs[train_idx],
   ...:             "sigma": self.sigma[train_idx],
   ...:             "indices": train_idx
   ...:         }
   ...:         test_data = {
   ...:             "y": self.y_obs[idx],
   ...:             "sigma": self.sigma[idx],
   ...:             "indices": idx
   ...:         }
   ...:         return train_data, test_data
   ...: 
   ...:     def sample(self, modified_observed_data):
   ...:         # Toy stand-in: a real wrapper would run the model's sampler here.
   ...:         # np.random is unseeded, so the outputs below vary between runs.
   ...:         train_y = modified_observed_data["y"]
   ...:         n = 1000
   ...:         mu = np.random.normal(train_y.mean(), 5, n)
   ...:         tau = np.abs(np.random.normal(10, 2, n))
   ...:         return {"mu": mu, "tau": tau}
   ...: 
   ...:     def get_inference_data(self, fitted_model):
   ...:         posterior = {
   ...:             "mu": fitted_model["mu"].reshape(1, -1),
   ...:             "tau": fitted_model["tau"].reshape(1, -1)
   ...:         }
   ...:         return from_dict({"posterior": posterior})
   ...: 
   ...:     def log_likelihood__i(self, excluded_obs, idata__i):
   ...:         test_y = excluded_obs["y"]
   ...:         test_sigma = excluded_obs["sigma"]
   ...:         mu = idata__i.posterior["mu"].values.flatten()
   ...:         tau = idata__i.posterior["tau"].values.flatten()
   ...: 
   ...:         var_total = tau[:, np.newaxis] ** 2 + test_sigma**2
   ...:         log_lik = stats.norm.logpdf(
   ...:             test_y, loc=mu[:, np.newaxis], scale=np.sqrt(var_total)
   ...:         )
   ...: 
   ...:         dims = ["chain", "school", "draw"]
   ...:         coords = {"school": excluded_obs["indices"]}
   ...:         return xr.DataArray(
   ...:             log_lik.T[np.newaxis, :, :], dims=dims, coords=coords
   ...:         )
   ...: 
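The log_likelihood__i method above scores each held-out school with its school effect integrated out, which is why the toy sample method only needs draws of mu and tau. Assuming the standard eight-schools hierarchy (not spelled out on this page), the marginal follows from adding the two independent Gaussian variances:

```latex
\theta_j \mid \mu, \tau \sim \mathcal{N}(\mu, \tau^2), \qquad
y_j \mid \theta_j \sim \mathcal{N}(\theta_j, \sigma_j^2)
\quad\Longrightarrow\quad
y_j \mid \mu, \tau \sim \mathcal{N}\!\left(\mu,\; \tau^2 + \sigma_j^2\right)
```

This is what var_total computes before it is passed as the squared scale to stats.norm.logpdf.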

Now let’s run k-fold cross-validation. With k=4, we’ll refit the model 4 times, each time leaving out 2 schools for testing:

In [2]: data = load_arviz_data("centered_eight")
   ...: wrapper = CenteredEightWrapper(data)
   ...: kfold_results = loo_kfold(data, wrapper, k=4, pointwise=True)
   ...: kfold_results
   ...: 
Out[2]: 
Computed from 4-fold cross validation.

           Estimate       SE
elpd_kfold   -31.94     0.88
p_kfold        2.11        -

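Sometimes we want more control over how the data is split. For instance, if the classes are imbalanced, stratified k-fold gives each fold a similar class distribution. Here the schools are split into two strata according to whether their observed effect is greater than 5: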

In [3]: strata = (data.observed_data["obs"] > 5).astype(int)
   ...: kfold_strat = loo_kfold(data, wrapper, k=4, stratify_by=strata)
   ...: kfold_strat
   ...: 
Out[3]: 
Computed from 4-fold cross validation.

           Estimate       SE
elpd_kfold   -31.40     0.74
p_kfold        1.56        -

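We may also need to keep related observations in the same fold. For instance, with repeated measurements from the same subject, grouping by subject prevents a subject's measurements from appearing in both the training and test sets of the same refit. Here, consecutive pairs of schools are treated as groups: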

In [4]: groups = xr.DataArray([1, 1, 2, 2, 3, 3, 4, 4], dims="school")
   ...: kfold_group = loo_kfold(data, wrapper, k=4, group_by=groups)
   ...: kfold_group
   ...: 
Out[4]: 
Computed from 4-fold cross validation.

           Estimate       SE
elpd_kfold   -31.83     0.75
p_kfold        1.99        -