arviz_stats.loo_kfold
arviz_stats.loo_kfold(data, wrapper, pointwise=None, var_name=None, k=10, folds=None, stratify_by=None, group_by=None, save_fits=False)
Perform exact K-fold cross-validation.
K-fold cross-validation evaluates model predictive accuracy by partitioning the data into K complementary subsets (folds), then iteratively refitting the model K times, each time holding out one fold as a test set and training on the remaining K-1 folds.
This method provides an unbiased estimate of model performance by ensuring each observation is used exactly once for testing. Unlike PSIS-LOO-CV (Pareto-smoothed importance sampling leave-one-out cross-validation), which approximates cross-validation efficiently, K-fold requires actual model refitting but yields exact results.
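To make the partitioning concrete, here is a minimal sketch (an illustration only, not the internal implementation used by loo_kfold) of random fold assignment and the resulting train/test splits:

import numpy as np

# Illustration only: random assignment of n observations to k folds.
# Assumes k divides n; loo_kfold itself also handles approximately equal folds.
rng = np.random.default_rng(0)
n, k = 8, 4
folds = rng.permutation(np.repeat(np.arange(1, k + 1), n // k))

for i in range(1, k + 1):
    test_idx = np.where(folds == i)[0]   # held-out fold i
    train_idx = np.where(folds != i)[0]  # remaining k-1 folds
    # refit the model on train_idx, then score log predictive density on test_idx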
Parameters:
data : xarray.DataTree or InferenceData
    Input data containing the posterior and log_likelihood groups from the full model fit.
wrapper : SamplingWrapper
    An instance of the SamplingWrapper class handling model refitting. The wrapper must implement the following methods: sel_observations, sample, get_inference_data, and log_likelihood__i.
pointwise : bool, optional
    If True, return pointwise estimates. Defaults to rcParams["stats.ic_pointwise"].
var_name : str, optional
    The name of the variable in the log_likelihood group storing the pointwise log likelihood data to use for computation.
k : int, default=10
    The number of folds for cross-validation. The data will be partitioned into k subsets of equal (or approximately equal) size.
folds : array or xarray.DataArray, optional
    Manual fold assignments (1 to k) for each observation. For example, [1, 1, 2, 2, 3, 3, 4, 4] assigns the first two observations to fold 1, the next two to fold 2, and so on (see the final example below). If not provided, creates k random folds of equal size. Cannot be used together with stratify_by or group_by.
stratify_by : array or xarray.DataArray, optional
    Stratification variable used to maintain class proportions across folds. For example, [0, 0, 1, 1, 0, 0, 1, 1] ensures each fold has 50% class 0 and 50% class 1. Cannot be used together with folds or group_by.
group_by : array or xarray.DataArray, optional
    Grouping variable to keep related observations together in the same fold. For example, [1, 1, 2, 2, 3, 3] keeps all observations from group 1 in one fold, group 2 in another, and so on. Useful for repeated measures or clustered data. Cannot be used together with folds or stratify_by.
save_fits : bool, default=False
    If True, store the fitted models and fold indices in the returned object.
Returns:
ELPDData
    Object with the following attributes:
    - elpd: expected log pointwise predictive density
    - se: standard error of the elpd
    - p: effective number of parameters
    - n_samples: number of samples per fold
    - n_data_points: number of data points
    - warning: True if any issues occurred during fitting
    - elpd_i: pointwise predictive accuracy (if pointwise=True)
    - p_kfold_i: pointwise effective number of parameters (if pointwise=True)
    - pareto_k: None (not applicable for k-fold)
    - scale: "log"
    Additional attributes when save_fits=True:
    - fold_fits: dictionary containing the fitted models for each fold
    - fold_indices: dictionary containing the test indices for each fold
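As an illustration of the save_fits option, a minimal sketch (assuming the data and wrapper objects built in the Examples section below):

# Sketch: keep the per-fold fits and test indices on the returned object.
results = loo_kfold(data, wrapper, k=4, save_fits=True)
for fold, fit in results.fold_fits.items():
    # fitted model and held-out indices for each fold
    print(fold, results.fold_indices[fold])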
See also
loo
Pareto-smoothed importance sampling LOO-CV
SamplingWrapper
Base class for implementing sampling wrappers
Notes
When k equals the number of observations, this becomes exact leave-one-out cross-validation. Note that arviz_stats.loo provides a much more efficient approximation for that case and is recommended.
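For instance, a sketch of that equivalence (using the data and wrapper objects from the Examples below): setting k to the number of observations performs exact LOO-CV by refitting once per observation, while arviz_stats.loo approximates the same quantity without any refitting:

from arviz_stats import loo

# Exact LOO-CV: k equals the number of observations, so one refit per observation.
n_obs = data.observed_data["obs"].size
loo_exact = loo_kfold(data, wrapper, k=n_obs)

# PSIS-LOO approximation: no refitting, usually the recommended route.
loo_approx = loo(data)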
References
[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4; arXiv preprint: https://arxiv.org/abs/1507.04544
[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html; arXiv preprint: https://arxiv.org/abs/1507.02646
Examples
Unlike PSIS-LOO-CV (which approximates LOO-CV), k-fold cross-validation refits the model k times, so we need to tell loo_kfold how to refit the model. This is done by creating an instance of the SamplingWrapper class that implements four key methods: sel_observations, sample, get_inference_data, and log_likelihood__i.

In [1]: import numpy as np
   ...: import xarray as xr
   ...: from scipy import stats
   ...: from arviz_base import load_arviz_data, from_dict
   ...: from arviz_stats import loo_kfold
   ...: from arviz_stats.loo import SamplingWrapper
   ...:
   ...: class CenteredEightWrapper(SamplingWrapper):
   ...:     def __init__(self, idata):
   ...:         super().__init__(model=None, idata_orig=idata)
   ...:         self.y_obs = idata.observed_data["obs"].values
   ...:         self.sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
   ...:
   ...:     def sel_observations(self, idx):
   ...:         all_idx = np.arange(len(self.y_obs))
   ...:         train_idx = np.setdiff1d(all_idx, idx)
   ...:
   ...:         train_data = {
   ...:             "y": self.y_obs[train_idx],
   ...:             "sigma": self.sigma[train_idx],
   ...:             "indices": train_idx,
   ...:         }
   ...:         test_data = {
   ...:             "y": self.y_obs[idx],
   ...:             "sigma": self.sigma[idx],
   ...:             "indices": idx,
   ...:         }
   ...:         return train_data, test_data
   ...:
   ...:     def sample(self, modified_observed_data):
   ...:         # Simplified stand-in; normally this would call the actual sampler
   ...:         train_y = modified_observed_data["y"]
   ...:         n = 1000
   ...:         mu = np.random.normal(train_y.mean(), 5, n)
   ...:         tau = np.abs(np.random.normal(10, 2, n))
   ...:         return {"mu": mu, "tau": tau}
   ...:
   ...:     def get_inference_data(self, fitted_model):
   ...:         posterior = {
   ...:             "mu": fitted_model["mu"].reshape(1, -1),
   ...:             "tau": fitted_model["tau"].reshape(1, -1),
   ...:         }
   ...:         return from_dict({"posterior": posterior})
   ...:
   ...:     def log_likelihood__i(self, excluded_obs, idata__i):
   ...:         test_y = excluded_obs["y"]
   ...:         test_sigma = excluded_obs["sigma"]
   ...:         mu = idata__i.posterior["mu"].values.flatten()
   ...:         tau = idata__i.posterior["tau"].values.flatten()
   ...:
   ...:         var_total = tau[:, np.newaxis] ** 2 + test_sigma**2
   ...:         log_lik = stats.norm.logpdf(
   ...:             test_y, loc=mu[:, np.newaxis], scale=np.sqrt(var_total)
   ...:         )
   ...:
   ...:         dims = ["chain", "school", "draw"]
   ...:         coords = {"school": excluded_obs["indices"]}
   ...:         return xr.DataArray(
   ...:             log_lik.T[np.newaxis, :, :], dims=dims, coords=coords
   ...:         )
Now let’s run k-fold cross-validation. With k=4, we’ll refit the model 4 times, each time leaving out 2 schools for testing:
In [2]: data = load_arviz_data("centered_eight")
   ...: wrapper = CenteredEightWrapper(data)
   ...: kfold_results = loo_kfold(data, wrapper, k=4, pointwise=True)
   ...: kfold_results

Out[2]:
Computed from 4-fold cross validation.

            Estimate    SE
elpd_kfold    -31.94  0.88
p_kfold         2.11     -
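Since we passed pointwise=True, the result should also expose the per-observation attributes described in the Returns section (elpd_i and p_kfold_i). For example:

# Sketch: per-school pointwise ELPD and effective number of parameters.
kfold_results.elpd_i
kfold_results.p_kfold_i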
Sometimes we want more control over how the data is split. For instance, if we have imbalanced groups, stratified k-fold ensures each fold has a similar class distribution:
In [3]: strata = (data.observed_data["obs"] > 5).astype(int)
   ...: kfold_strat = loo_kfold(data, wrapper, k=4, stratify_by=strata)
   ...: kfold_strat

Out[3]:
Computed from 4-fold cross validation.

            Estimate    SE
elpd_kfold    -31.40  0.74
p_kfold         1.56     -
Sometimes we instead want to keep related observations together. For instance, if we have repeated measurements from the same subject, we can group by subject:
In [4]: groups = xr.DataArray([1, 1, 2, 2, 3, 3, 4, 4], dims="school")
   ...: kfold_group = loo_kfold(data, wrapper, k=4, group_by=groups)
   ...: kfold_group

Out[4]:
Computed from 4-fold cross validation.

            Estimate    SE
elpd_kfold    -31.83  0.75
p_kfold         1.99     -
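Finally, as noted in the folds parameter description, fold membership can also be fixed by hand. A minimal sketch (no session output shown) that pairs consecutive schools into folds 1 through 4:

# Sketch: manual fold assignment; observations sharing a label land in the same fold.
manual_folds = xr.DataArray([1, 1, 2, 2, 3, 3, 4, 4], dims="school")
kfold_manual = loo_kfold(data, wrapper, k=4, folds=manual_folds)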