"""K-fold cross-validation."""
import numpy as np
from arviz_base import rcParams
from arviz_stats.loo.helper_loo_kfold import (
_combine_fold_elpds,
_compute_kfold_results,
_prepare_kfold_inputs,
)
from arviz_stats.utils import ELPDData
[docs]
def loo_kfold(
    data,
    wrapper,
    pointwise=None,
    var_name=None,
    k=10,
    folds=None,
    stratify_by=None,
    group_by=None,
    save_fits=False,
):
    """Perform exact K-fold cross-validation.

    K-fold cross-validation evaluates model predictive accuracy by partitioning the data
    into K complementary subsets (folds), then refitting the model K times, each time
    holding out one fold as the test set and training on the remaining K-1 folds.
    Every observation is held out exactly once, so each pointwise term is a genuine
    out-of-sample log predictive density. Unlike PSIS-LOO-CV (Pareto-smoothed importance
    sampling leave-one-out cross-validation), which approximates cross-validation without
    refitting, K-fold requires actual model refits but involves no importance-sampling
    approximation.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data containing the posterior and log_likelihood groups from the full model fit.
    wrapper : SamplingWrapper
        An instance of SamplingWrapper class handling model refitting. The wrapper must
        implement the following methods: sel_observations, sample, get_inference_data,
        and log_likelihood__i.
    pointwise : bool, optional
        If True, return pointwise estimates. Defaults to ``rcParams["stats.ic_pointwise"]``.
    var_name : str, optional
        The name of the variable in log_likelihood group storing the pointwise log
        likelihood data to use for computation.
    k : int, default=10
        The number of folds for cross-validation. The data will be partitioned into k subsets
        of equal (or approximately equal) size.
    folds : array or DataArray, optional
        Manual fold assignments (1 to k) for each observation. For example, [1,1,2,2,3,3,4,4]
        assigns the first two observations to fold 1, the next two to fold 2, etc. If not
        provided, creates k random folds of approximately equal size. Cannot be combined
        with ``stratify_by`` or ``group_by``.
    stratify_by : array or DataArray, optional
        Maintains class proportions across folds. For example, [0,0,1,1,0,0,1,1] ensures each
        fold has 50% class 0 and 50% class 1. Cannot be combined with ``folds`` or
        ``group_by``.
    group_by : array or DataArray, optional
        Grouping variable to keep related observations together in the same fold. For example,
        [1,1,2,2,3,3] keeps all observations from group 1 in one fold, group 2 in another,
        etc. Useful for repeated measures or clustered data. Cannot be combined with
        ``folds`` or ``stratify_by``.
    save_fits : bool, default=False
        If True, store the per-fold fits in the returned object as ``fold_fits``.

    Returns
    -------
    ELPDData
        Object with the following attributes:

        - **elpd**: expected log pointwise predictive density
        - **se**: standard error of the elpd
        - **p**: effective number of parameters
        - **n_samples**: number of samples per fold
        - **n_data_points**: number of data points
        - **n_folds**: number of folds refit (the number of distinct labels in ``folds``
          when it is given, otherwise ``k``)
        - **warning**: always False (no diagnostic is computed for k-fold)
        - **elpd_i**: pointwise predictive accuracy (if ``pointwise=True``)
        - **p_kfold_i**: pointwise effective number of parameters (if ``pointwise=True``)
        - **pareto_k**: None (not applicable for k-fold)
        - **good_k**: None (not applicable for k-fold)
        - **scale**: "log"

        Additional attribute when ``save_fits=True``:

        - **fold_fits**: the per-fold fits produced during cross-validation

    Examples
    --------
    Unlike PSIS-LOO (which approximates LOO-CV), k-fold cross-validation refits
    the model k times, so we need to tell ``loo_kfold`` how to refit the model.
    This is done by creating an instance of the ``SamplingWrapper`` class that
    implements four key methods: ``sel_observations``, ``sample``, ``get_inference_data``,
    and ``log_likelihood__i``.

    .. ipython::

        In [1]: import numpy as np
           ...: import xarray as xr
           ...: from scipy import stats
           ...: from arviz_base import load_arviz_data, from_dict
           ...: from arviz_stats import loo_kfold
           ...: from arviz_stats.loo import SamplingWrapper
           ...:
           ...: class CenteredEightWrapper(SamplingWrapper):
           ...:     def __init__(self, idata):
           ...:         super().__init__(model=None, idata_orig=idata)
           ...:         self.y_obs = idata.observed_data["obs"].values
           ...:         self.sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
           ...:
           ...:     def sel_observations(self, idx):
           ...:         all_idx = np.arange(len(self.y_obs))
           ...:         train_idx = np.setdiff1d(all_idx, idx)
           ...:
           ...:         train_data = {
           ...:             "y": self.y_obs[train_idx],
           ...:             "sigma": self.sigma[train_idx],
           ...:             "indices": train_idx
           ...:         }
           ...:         test_data = {
           ...:             "y": self.y_obs[idx],
           ...:             "sigma": self.sigma[idx],
           ...:             "indices": idx
           ...:         }
           ...:         return train_data, test_data
           ...:
           ...:     def sample(self, modified_observed_data):
           ...:         # (Simplified version where we normally would use the actual sampler)
           ...:         train_y = modified_observed_data["y"]
           ...:         n = 1000
           ...:         mu = np.random.normal(train_y.mean(), 5, n)
           ...:         tau = np.abs(np.random.normal(10, 2, n))
           ...:         return {"mu": mu, "tau": tau}
           ...:
           ...:     def get_inference_data(self, fitted_model):
           ...:         posterior = {
           ...:             "mu": fitted_model["mu"].reshape(1, -1),
           ...:             "tau": fitted_model["tau"].reshape(1, -1)
           ...:         }
           ...:         return from_dict({"posterior": posterior})
           ...:
           ...:     def log_likelihood__i(self, excluded_obs, idata__i):
           ...:         test_y = excluded_obs["y"]
           ...:         test_sigma = excluded_obs["sigma"]
           ...:         mu = idata__i.posterior["mu"].values.flatten()
           ...:         tau = idata__i.posterior["tau"].values.flatten()
           ...:
           ...:         var_total = tau[:, np.newaxis] ** 2 + test_sigma**2
           ...:         log_lik = stats.norm.logpdf(
           ...:             test_y, loc=mu[:, np.newaxis], scale=np.sqrt(var_total)
           ...:         )
           ...:
           ...:         dims = ["chain", "school", "draw"]
           ...:         coords = {"school": excluded_obs["indices"]}
           ...:         return xr.DataArray(
           ...:             log_lik.T[np.newaxis, :, :], dims=dims, coords=coords
           ...:         )

    Now let's run k-fold cross-validation. With k=4, we'll refit the model 4 times,
    each time leaving out 2 schools for testing:

    .. ipython::

        In [2]: data = load_arviz_data("centered_eight")
           ...: wrapper = CenteredEightWrapper(data)
           ...: kfold_results = loo_kfold(data, wrapper, k=4, pointwise=True)
           ...: kfold_results

    Sometimes we want more control over how the data is split. For instance,
    if you have imbalanced groups, stratified k-fold ensures each fold has
    a similar distribution:

    .. ipython::

        In [3]: strata = (data.observed_data["obs"] > 5).astype(int)
           ...: kfold_strat = loo_kfold(data, wrapper, k=4, stratify_by=strata)
           ...: kfold_strat

    Moreover, sometimes we want to group observations together. For instance,
    if we have repeated measurements from the same subject, we can group by subject:

    .. ipython::

        In [4]: groups = xr.DataArray([1, 1, 2, 2, 3, 3, 4, 4], dims="school")
           ...: kfold_group = loo_kfold(data, wrapper, k=4, group_by=groups)
           ...: kfold_group

    Notes
    -----
    When K equals the number of observations, this becomes exact leave-one-out
    cross-validation. Note that :func:`arviz_stats.loo` provides a much more efficient
    approximation for that case and is recommended.

    See Also
    --------
    loo : Pareto-smoothed importance sampling LOO-CV
    SamplingWrapper : Base class for implementing sampling wrappers

    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

    kfold_inputs = _prepare_kfold_inputs(data, var_name, wrapper, k, folds, stratify_by, group_by)
    kfold_results = _compute_kfold_results(kfold_inputs, wrapper, save_fits)

    combined_results = _combine_fold_elpds([kfold_results.elpds], kfold_inputs.n_data_points)
    elpd_sum = combined_results["elpd_kfold"]
    se_elpd = combined_results["se_elpd_kfold"]
    # Collapse the per-fold effective-parameter values into a single total p.
    p_sum = np.sum(kfold_results.ps)

    if folds is not None:
        # With manual assignments the folds that were actually refit are the distinct
        # labels in ``folds``; ``k`` keeps its default (10) unless the caller also set it,
        # so reporting ``k`` could misstate the number of folds.
        n_folds = len(np.unique(np.asarray(folds)))
    else:
        n_folds = k

    elpd_data = ELPDData(
        kind="loo_kfold",
        elpd=elpd_sum,
        se=se_elpd,
        p=p_sum,
        n_samples=kfold_inputs.n_samples,
        n_data_points=kfold_inputs.n_data_points,
        scale="log",
        # No Pareto-k style diagnostic exists for exact refits, so no warning is raised.
        warning=False,
        good_k=None,
        elpd_i=kfold_results.elpd_i if pointwise else None,
        pareto_k=None,
        n_folds=n_folds,
    )

    if save_fits:
        # Only ``fold_fits`` is attached here; no separate fold-index attribute is set.
        elpd_data.fold_fits = kfold_results.fold_fits

    if pointwise:
        elpd_data.p_kfold_i = kfold_results.p_kfold_i

    return elpd_data