Source code for arviz_stats.loo.loo_kfold

"""K-fold cross-validation."""

import numpy as np
from arviz_base import rcParams

from arviz_stats.loo.helper_loo_kfold import (
    _combine_fold_elpds,
    _compute_kfold_results,
    _prepare_kfold_inputs,
)
from arviz_stats.utils import ELPDData


[docs] def loo_kfold( data, wrapper, pointwise=None, var_name=None, k=10, folds=None, stratify_by=None, group_by=None, save_fits=False, ): """Perform exact K-fold cross-validation. K-fold cross-validation evaluates model predictive accuracy by partitioning the data into K complementary subsets (folds), then iteratively refitting the model K times, each time holding out one fold as a test set and training on the remaining K-1 folds. This method provides an unbiased estimate of model performance by ensuring each observation is used exactly once for testing. Unlike PSIS-LOO-CV (Pareto-smoothed importance sampling leave-one-out cross-validation), which approximates cross-validation efficiently, K-fold requires actual model refitting but yields exact results. Parameters ---------- data : DataTree or InferenceData Input data containing the posterior and log_likelihood groups from the full model fit. wrapper : SamplingWrapper An instance of SamplingWrapper class handling model refitting. The wrapper must implement the following methods: sel_observations, sample, get_inference_data, and log_likelihood__i. pointwise : bool, optional If True, return pointwise estimates. Defaults to ``rcParams["stats.ic_pointwise"]``. var_name : str, optional The name of the variable in log_likelihood group storing the pointwise log likelihood data to use for computation. k : int, default=10 The number of folds for cross-validation. The data will be partitioned into k subsets of equal (or approximately equal) size. folds : array or DataArray, optional Manual fold assignments (1 to k) for each observation. For example, [1,1,2,2,3,3,4,4] assigns first two obs to fold 1, next two to fold 2, etc. If not provided, creates k random folds of equal size. Cannot be used together with `stratify_by` and `group_by`. stratify_by : array or DataArray, optional Maintains class proportions across folds. For example, [0,0,1,1,0,0,1,1] ensures each fold has 50% class 0 and 50% class 1. Cannot be used together with `folds` and `group_by`. group_by : array or DataArray, optional Grouping variable to keep related observations together in the same fold. For example, [1,1,2,2,3,3] keeps all obs from group 1 in one fold, group 2 in another, etc. Useful for repeated measures or clustered data. Cannot be used together with `folds` and `stratify_by`. save_fits : bool, default=False If True, store the fitted models and fold indices in the returned object. Returns ------- ELPDData Object with the following attributes: - **elpd**: expected log pointwise predictive density - **se**: standard error of the elpd - **p**: effective number of parameters - **n_samples**: number of samples per fold - **n_data_points**: number of data points - **warning**: True if any issues occurred during fitting - **elpd_i**: pointwise predictive accuracy (if ``pointwise=True``) - **p_kfold_i**: pointwise effective number of parameters (if ``pointwise=True``) - **pareto_k**: None (not applicable for k-fold) - **scale**: "log" Additional attributes when ``save_fits=True``: - **fold_fits**: Dictionary containing fitted models for each fold - **fold_indices**: Dictionary containing test indices for each fold Examples -------- Unlike PSIS-LOO (which approximates LOO-CV), k-fold cross-validation refits the model k times. So we need to tell ``loo_kfold`` how to refit the model. This is done by creating an instance of the ``SamplingWrapper`` class that implements four key methods: ``sel_observations``, ``sample``, ``get_inference_data``, and ``log_likelihood__i``. .. ipython:: In [1]: import numpy as np ...: import xarray as xr ...: from scipy import stats ...: from arviz_base import load_arviz_data, from_dict ...: from arviz_stats import loo_kfold ...: from arviz_stats.loo import SamplingWrapper ...: ...: class CenteredEightWrapper(SamplingWrapper): ...: def __init__(self, idata): ...: super().__init__(model=None, idata_orig=idata) ...: self.y_obs = idata.observed_data["obs"].values ...: self.sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18]) ...: ...: def sel_observations(self, idx): ...: all_idx = np.arange(len(self.y_obs)) ...: train_idx = np.setdiff1d(all_idx, idx) ...: ...: train_data = { ...: "y": self.y_obs[train_idx], ...: "sigma": self.sigma[train_idx], ...: "indices": train_idx ...: } ...: test_data = { ...: "y": self.y_obs[idx], ...: "sigma": self.sigma[idx], ...: "indices": idx ...: } ...: return train_data, test_data ...: ...: def sample(self, modified_observed_data): ...: # (Simplified version where we normally would use the actual sampler) ...: train_y = modified_observed_data["y"] ...: n = 1000 ...: mu = np.random.normal(train_y.mean(), 5, n) ...: tau = np.abs(np.random.normal(10, 2, n)) ...: return {"mu": mu, "tau": tau} ...: ...: def get_inference_data(self, fitted_model): ...: posterior = { ...: "mu": fitted_model["mu"].reshape(1, -1), ...: "tau": fitted_model["tau"].reshape(1, -1) ...: } ...: return from_dict({"posterior": posterior}) ...: ...: def log_likelihood__i(self, excluded_obs, idata__i): ...: test_y = excluded_obs["y"] ...: test_sigma = excluded_obs["sigma"] ...: mu = idata__i.posterior["mu"].values.flatten() ...: tau = idata__i.posterior["tau"].values.flatten() ...: ...: var_total = tau[:, np.newaxis] ** 2 + test_sigma**2 ...: log_lik = stats.norm.logpdf( ...: test_y, loc=mu[:, np.newaxis], scale=np.sqrt(var_total) ...: ) ...: ...: dims = ["chain", "school", "draw"] ...: coords = {"school": excluded_obs["indices"]} ...: return xr.DataArray( ...: log_lik.T[np.newaxis, :, :], dims=dims, coords=coords ...: ) Now let's run k-fold cross-validation. With k=4, we'll refit the model 4 times, each time leaving out 2 schools for testing: .. ipython:: In [2]: data = load_arviz_data("centered_eight") ...: wrapper = CenteredEightWrapper(data) ...: kfold_results = loo_kfold(data, wrapper, k=4, pointwise=True) ...: kfold_results Sometimes we want more control over how the data is split. For instance, if you have imbalanced groups, stratified k-fold ensures each fold has a similar distribution: .. ipython:: In [3]: strata = (data.observed_data["obs"] > 5).astype(int) ...: kfold_strat = loo_kfold(data, wrapper, k=4, stratify_by=strata) ...: kfold_strat Moreover, sometimes we want to group observations together. For instance, if we have repeated measurements from the same subject, we can group by subject: .. ipython:: In [4]: groups = xr.DataArray([1, 1, 2, 2, 3, 3, 4, 4], dims="school") ...: kfold_group = loo_kfold(data, wrapper, k=4, group_by=groups) ...: kfold_group Notes ----- When K equals the number of observations, this becomes exact leave-one-out cross-validation. Note that :func:`arviz_stats.loo` provides a much more efficient approximation for that case and is recommended. See Also -------- loo : Pareto-smoothed importance sampling LOO-CV SamplingWrapper : Base class for implementing sampling wrappers References ---------- .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544. .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646 """ pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise kfold_inputs = _prepare_kfold_inputs(data, var_name, wrapper, k, folds, stratify_by, group_by) kfold_results = _compute_kfold_results(kfold_inputs, wrapper, save_fits) combined_results = _combine_fold_elpds([kfold_results.elpds], kfold_inputs.n_data_points) elpd_sum = combined_results["elpd_kfold"] se_elpd = combined_results["se_elpd_kfold"] p_sum = np.sum(kfold_results.ps) elpd_data = ELPDData( kind="loo_kfold", elpd=elpd_sum, se=se_elpd, p=p_sum, n_samples=kfold_inputs.n_samples, n_data_points=kfold_inputs.n_data_points, scale="log", warning=False, good_k=None, elpd_i=kfold_results.elpd_i if pointwise else None, pareto_k=None, n_folds=k, ) if save_fits: elpd_data.fold_fits = kfold_results.fold_fits if pointwise: elpd_data.p_kfold_i = kfold_results.p_kfold_i return elpd_data