Source code for arviz_stats.loo.loo_moment_match

"""Compute moment matching for problematic observations in PSIS-LOO-CV."""

import warnings
from collections import namedtuple
from copy import deepcopy

import arviz_base as azb
import numpy as np
import xarray as xr
from arviz_base import dataset_to_dataarray, rcParams
from xarray_einstats.stats import logsumexp

from arviz_stats.loo.helper_loo import (
    _get_log_likelihood_i,
    _get_r_eff,
    _prepare_loo_inputs,
    _shift,
    _shift_and_cov,
    _shift_and_scale,
    _warn_pareto_k,
)
from arviz_stats.sampling_diagnostics import ess
from arviz_stats.utils import ELPDData

SplitMomentMatch = namedtuple("SplitMomentMatch", ["lwi", "lwfi", "log_liki", "reff"])
UpdateQuantities = namedtuple("UpdateQuantities", ["lwi", "lwfi", "ki", "kfi", "log_liki"])
LooMomentMatchResult = namedtuple(
    "LooMomentMatchResult",
    ["final_log_liki", "final_lwi", "final_ki", "kfs_i", "reff_i", "original_ki", "i"],
)
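
# Lightweight containers used to pass intermediate results between the helpers below:
# ``SplitMomentMatch`` holds the output of the split transformation, ``UpdateQuantities``
# the refreshed weights and Pareto k values after a candidate transformation, and
# ``LooMomentMatchResult`` the final per-observation quantities returned by
# ``_loo_moment_match_i``.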


def loo_moment_match(
    data,
    loo_orig,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
    upars=None,
    var_name=None,
    reff=None,
    max_iters=30,
    k_threshold=None,
    split=True,
    cov=False,
    pointwise=None,
):
    r"""Compute moment matching for problematic observations in PSIS-LOO-CV.

    Adjusts the results of a previously computed Pareto smoothed importance sampling
    leave-one-out cross-validation (PSIS-LOO-CV) object by applying a moment matching
    algorithm to observations with high Pareto k diagnostic values. The moment matching
    algorithm iteratively adjusts the posterior draws in the unconstrained parameter
    space to better approximate the leave-one-out posterior.

    The moment matching algorithm is described in [1]_ and the PSIS-LOO-CV method is
    described in [2]_ and [3]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood groups.
    loo_orig : ELPDData
        An existing ELPDData object from a previous `loo` result. Must contain pointwise
        Pareto k values (`pointwise=True` must have been used).
    log_prob_upars_fn : Callable[[DataArray], DataArray]
        A function that takes the unconstrained parameter draws and returns a
        :class:`~xarray.DataArray` containing the log probability density of the full
        posterior distribution evaluated at each unconstrained parameter draw.
        The returned DataArray must have dimensions `chain`, `draw`.
    log_lik_i_upars_fn : Callable[[DataArray, int], DataArray]
        A function that takes the unconstrained parameter draws and the integer index
        `i` of the left-out observation. It should return a :class:`~xarray.DataArray`
        containing the log-likelihood of the left-out observation `i` evaluated at each
        unconstrained parameter draw. The returned DataArray must have dimensions
        `chain`, `draw`.
    upars : DataArray, optional
        Posterior draws transformed to the unconstrained parameter space. Must have
        `chain` and `draw` dimensions, plus one additional dimension containing all
        parameters. Parameter names can be provided as coordinate values on this
        dimension. If not provided, will attempt to use the `unconstrained_posterior`
        group from the input data if available.
    var_name : str, optional
        The name of the variable in log_likelihood groups storing the pointwise log
        likelihood data to use for loo computation.
    reff : float, optional
        Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided
        by the number of actual samples. Computed from trace by default.
    max_iters : int, default 30
        Maximum number of moment matching iterations for each problematic observation.
    k_threshold : float, optional
        Threshold value for Pareto k values above which moment matching is applied.
        Defaults to :math:`\min(1 - 1/\log_{10}(S), 0.7)`, where S is the number of
        samples.
    split : bool, default True
        If True, only transform half of the draws and use multiple importance sampling
        to combine them with untransformed draws.
    cov : bool, default False
        If True, match the covariance structure during the transformation, in addition
        to the mean and marginal variances. Ignored if ``split=False``.
    pointwise : bool, optional
        If True, the pointwise predictive accuracy will be returned. Defaults to
        ``rcParams["stats.ic_pointwise"]``. Moment matching always requires pointwise
        data from `loo_orig`. This argument controls whether the returned object
        includes pointwise data.

    Returns
    -------
    ELPDData
        Object with the following attributes:

        - **elpd**: expected log pointwise predictive density
        - **se**: standard error of the elpd
        - **p**: effective number of parameters
        - **n_samples**: number of samples
        - **n_data_points**: number of data points
        - **warning**: True if the estimated shape parameter of Pareto distribution is
          greater than ``good_k``.
        - **elpd_i**: :class:`~xarray.DataArray` with the pointwise predictive accuracy,
          only if ``pointwise=True``
        - **pareto_k**: array of Pareto shape values, only if ``pointwise=True``
        - **good_k**: For a sample size S, the threshold is computed as
          ``min(1 - 1/log10(S), 0.7)``
        - **approx_posterior**: True if approximate posterior was used.

    Examples
    --------
    Moment matching can improve PSIS-LOO-CV estimates for observations with high
    Pareto k values without having to refit the model for each problematic observation.
    We will use the non-centered eight schools data which has 1 problematic observation:

    .. ipython::
        :okwarning:

        In [1]: import arviz_base as az
           ...: from arviz_stats import loo, loo_moment_match
           ...: import numpy as np
           ...: import xarray as xr
           ...: from scipy import stats
           ...:
           ...: idata = az.load_arviz_data("non_centered_eight")
           ...: loo_orig = loo(idata, pointwise=True, var_name="obs")
           ...: loo_orig

    For moment matching, we need the unconstrained parameters and two functions for the
    log probability and pointwise log-likelihood computations:

    .. ipython::

        In [3]: posterior = idata.posterior
           ...: theta_t = posterior.theta_t.values
           ...: mu = posterior.mu.values[:, :, np.newaxis]
           ...: log_tau = np.log(posterior.tau.values)[:, :, np.newaxis]
           ...:
           ...: upars = np.concatenate([theta_t, mu, log_tau], axis=2)
           ...: param_names = [f"theta_t_{i}" for i in range(8)] + ["mu", "log_tau"]
           ...:
           ...: upars = xr.DataArray(
           ...:     upars,
           ...:     dims=["chain", "draw", "upars_dim"],
           ...:     coords={
           ...:         "chain": posterior.chain,
           ...:         "draw": posterior.draw,
           ...:         "upars_dim": param_names
           ...:     }
           ...: )
           ...:
           ...: def log_prob_upars(upars):
           ...:     theta_tilde = upars.sel(upars_dim=[f"theta_t_{i}" for i in range(8)])
           ...:     mu = upars.sel(upars_dim="mu")
           ...:     log_tau = upars.sel(upars_dim="log_tau")
           ...:     tau = np.exp(log_tau)
           ...:
           ...:     log_prob = stats.norm(0, 5).logpdf(mu.values)
           ...:     log_prob += stats.halfcauchy(0, 5).logpdf(tau.values)
           ...:     log_prob += log_tau.values
           ...:     log_prob += stats.norm(0, 1).logpdf(theta_tilde.values).sum(axis=-1)
           ...:
           ...:     return xr.DataArray(
           ...:         log_prob,
           ...:         dims=["chain", "draw"],
           ...:         coords={"chain": upars.chain, "draw": upars.draw}
           ...:     )
           ...:
           ...: def log_lik_i_upars(upars, i):
           ...:     sigmas = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
           ...:     theta_tilde_i = upars.sel(upars_dim=f"theta_t_{i}")
           ...:     mu = upars.sel(upars_dim="mu")
           ...:     tau = np.exp(upars.sel(upars_dim="log_tau"))
           ...:     theta_i = mu + tau * theta_tilde_i
           ...:     y_i = idata.observed_data.obs.values[i]
           ...:     log_lik = stats.norm(theta_i.values, sigmas[i]).logpdf(y_i)
           ...:
           ...:     return xr.DataArray(
           ...:         log_lik,
           ...:         dims=["chain", "draw"],
           ...:         coords={"chain": upars.chain, "draw": upars.draw}
           ...:     )

    We can now apply moment matching using the split transformation. We can see that all
    Pareto :math:`k` values are now below the threshold and the ELPD is slightly
    improved:

    .. ipython::
        :okwarning:

        In [4]: loo_mm = loo_moment_match(
           ...:     idata,
           ...:     loo_orig,
           ...:     upars=upars,
           ...:     log_prob_upars_fn=log_prob_upars,
           ...:     log_lik_i_upars_fn=log_lik_i_upars,
           ...:     var_name="obs",
           ...:     k_threshold=0.7,
           ...:     split=True,
           ...:     cov=False,
           ...: )
           ...: loo_mm

    Notes
    -----
    The moment matching algorithm considers three affine transformations of the
    posterior draws. For a specific draw :math:`\theta^{(s)}`, a generic affine
    transformation includes a square matrix :math:`\mathbf{A}` representing a linear
    map and a vector :math:`\mathbf{b}` representing a translation such that

    .. math::
        T : \theta^{(s)} \mapsto \mathbf{A}\theta^{(s)} + \mathbf{b} =: \theta^{*{(s)}}.

    The first transformation, :math:`T_1`, is a translation that matches the mean of
    the sample to its importance weighted mean given by

    .. math::
        \mathbf{\theta^{*{(s)}}} = T_1(\mathbf{\theta^{(s)}}) =
        \mathbf{\theta^{(s)}} - \bar{\theta} + \bar{\theta}_w,

    where :math:`\bar{\theta}` is the mean of the sample and :math:`\bar{\theta}_w` is
    the importance weighted mean of the sample.

    The second transformation, :math:`T_2`, is a scaling that matches the marginal
    variances in addition to the means given by

    .. math::
        \mathbf{\theta^{*{(s)}}} = T_2(\mathbf{\theta^{(s)}}) =
        \mathbf{v}^{1/2}_w \circ \mathbf{v}^{-1/2} \circ
        (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,

    where :math:`\mathbf{v}` and :math:`\mathbf{v}_w` are the sample and weighted
    variances, and :math:`\circ` denotes the pointwise product of the elements of two
    vectors.

    The third transformation, :math:`T_3`, is a covariance transformation that matches
    the covariance matrix of the sample to its importance weighted covariance matrix
    given by

    .. math::
        \mathbf{\theta^{*{(s)}}} = T_3(\mathbf{\theta^{(s)}}) =
        \mathbf{L}_w \mathbf{L}^{-1} (\mathbf{\theta^{(s)}} - \bar{\theta}) +
        \bar{\theta}_w,

    where :math:`\mathbf{L}` and :math:`\mathbf{L}_w` are the Cholesky decompositions
    of the covariance matrix and the weighted covariance matrix, respectively, e.g.,

    .. math::
        \mathbf{LL}^T = \mathbf{\Sigma} = \frac{1}{S} \sum_{s=1}^S
        (\mathbf{\theta^{(s)}} - \bar{\theta}) (\mathbf{\theta^{(s)}} - \bar{\theta})^T

    and

    .. math::
        \mathbf{L}_w \mathbf{L}_w^T = \mathbf{\Sigma}_w =
        \frac{\frac{1}{S} \sum_{s=1}^S w^{(s)}
        (\mathbf{\theta^{(s)}} - \bar{\theta}_w)
        (\mathbf{\theta^{(s)}} - \bar{\theta}_w)^T}{\sum_{s=1}^S w^{(s)}}.

    We iterate on :math:`T_1` repeatedly and move on to :math:`T_2` and :math:`T_3`
    only if :math:`T_1` fails to yield a Pareto-k statistic below the threshold.

    See Also
    --------
    loo : Standard PSIS-LOO-CV.
    reloo : Exact re-fitting for problematic observations.

    References
    ----------
    .. [1] Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. (2021). *Implicitly
        Adaptive Importance Sampling*. Statistics and Computing. 31(2) (2021)
        https://doi.org/10.1007/s11222-020-09982-2
        arXiv preprint https://arxiv.org/abs/1906.08850.

    .. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing. 27(5) (2017)
        https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*. Journal of Machine
        Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
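    # Overall flow: validate ``loo_orig`` and the unconstrained draws, evaluate the
    # full-posterior log density once, locate observations whose Pareto k exceeds the
    # threshold, run the per-observation moment matching in ``_loo_moment_match_i``,
    # and fold the improved pointwise results back into a copy of ``loo_orig``.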
    if not isinstance(loo_orig, ELPDData):
        raise TypeError("loo_orig must be an ELPDData object.")
    if loo_orig.pareto_k is None or loo_orig.elpd_i is None:
        raise ValueError(
            "Moment matching requires pointwise LOO results with Pareto k values. "
            "Please compute the initial LOO with pointwise=True."
        )

    sample_dims = ["chain", "draw"]

    if upars is None:
        if hasattr(data, "unconstrained_posterior"):
            upars_ds = azb.get_unconstrained_samples(data, return_dataset=True)
            upars = dataset_to_dataarray(
                upars_ds, sample_dims=sample_dims, new_dim="unconstrained_parameter"
            )
        else:
            raise ValueError(
                "upars must be provided or data must contain an 'unconstrained_posterior' group."
            )

    if not isinstance(upars, xr.DataArray):
        raise TypeError("upars must be a DataArray.")
    if not all(dim_name in upars.dims for dim_name in sample_dims):
        raise ValueError(f"upars must have dimensions {sample_dims}.")

    param_dim_list = [dim for dim in upars.dims if dim not in sample_dims]
    if len(param_dim_list) == 0:
        param_dim_name = "upars_dim"
        upars = upars.expand_dims(dim={param_dim_name: 1})
    elif len(param_dim_list) == 1:
        param_dim_name = param_dim_list[0]
    else:
        raise ValueError("upars must have at most one dimension besides 'chain' and 'draw'.")

    loo_data = deepcopy(loo_orig)
    loo_data.method = "loo_moment_match"
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

    loo_inputs = _prepare_loo_inputs(data, var_name)
    log_likelihood = loo_inputs.log_likelihood
    obs_dims = loo_inputs.obs_dims
    n_samples = loo_inputs.n_samples
    var_name = loo_inputs.var_name
    n_params = upars.sizes[param_dim_name]
    n_data_points = loo_orig.n_data_points

    if reff is None:
        reff = _get_r_eff(data, n_samples)

    try:
        orig_log_prob = log_prob_upars_fn(upars)
        if not isinstance(orig_log_prob, xr.DataArray):
            raise TypeError("log_prob_upars_fn must return a DataArray.")
        if not all(dim in orig_log_prob.dims for dim in sample_dims):
            raise ValueError(f"Original log probability must have dimensions {sample_dims}.")
        if len(orig_log_prob.dims) != len(sample_dims):
            raise ValueError(
                f"Original log probability should only have dimensions {sample_dims}, "
                f"found {orig_log_prob.dims}"
            )
    except Exception as e:
        raise ValueError(f"Error executing log_prob_upars_fn: {e}") from e

    if k_threshold is None:
        k_threshold = min(1 - 1 / np.log10(n_samples), 0.7) if n_samples > 1 else 0.7

    ks = loo_data.pareto_k.stack(__obs__=obs_dims).transpose("__obs__").values
    bad_obs_indices = np.where(ks > k_threshold)[0]

    if len(bad_obs_indices) == 0:
        warnings.warn("No Pareto k values exceed the threshold. Returning original LOO data.")
        if not pointwise:
            loo_data.elpd_i = None
            loo_data.pareto_k = None
            if hasattr(loo_data, "p_loo_i"):
                loo_data.p_loo_i = None
        return loo_data
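
    # Compute the baseline pointwise lpd and p_loo, then run the moment matching loop
    # for every flagged observation. The new pointwise values are kept only when the
    # Pareto k estimate actually improves; otherwise the original result for that
    # observation is restored.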

    lpd = logsumexp(log_likelihood, dims=sample_dims, b=1 / n_samples)
    loo_data.p_loo_i = lpd - loo_data.elpd_i
    kfs = np.zeros(n_data_points)

    # Moment matching algorithm
    for i in bad_obs_indices:
        mm_result = _loo_moment_match_i(
            i=i,
            upars=upars,
            log_likelihood=log_likelihood,
            log_prob_upars_fn=log_prob_upars_fn,
            log_lik_i_upars_fn=log_lik_i_upars_fn,
            max_iters=max_iters,
            k_threshold=k_threshold,
            split=split,
            cov=cov,
            orig_log_prob=orig_log_prob,
            ks=ks,
            sample_dims=sample_dims,
            obs_dims=obs_dims,
            n_samples=n_samples,
            n_params=n_params,
            param_dim_name=param_dim_name,
        )

        kfs[i] = mm_result.kfs_i

        if mm_result.final_ki < mm_result.original_ki:
            new_elpd_i = logsumexp(
                mm_result.final_log_liki + mm_result.final_lwi, dims=sample_dims
            ).item()
            original_log_liki = _get_log_likelihood_i(log_likelihood, i, obs_dims)

            _update_loo_data_i(
                loo_data,
                i,
                new_elpd_i,
                mm_result.final_ki,
                mm_result.final_log_liki,
                sample_dims,
                obs_dims,
                n_samples,
                original_log_liki,
                suppress_warnings=True,
            )
        else:
            warnings.warn(
                f"Observation {i}: Moment matching did not improve k "
                f"({mm_result.original_ki:.2f} -> {mm_result.final_ki:.2f}). Reverting.",
                UserWarning,
                stacklevel=2,
            )
            if hasattr(loo_orig, "p_loo_i") and loo_orig.p_loo_i is not None:
                if len(obs_dims) == 1:
                    idx_dict = {obs_dims[0]: i}
                else:
                    coords = np.unravel_index(
                        i, tuple(loo_data.elpd_i.sizes[d] for d in obs_dims)
                    )
                    idx_dict = dict(zip(obs_dims, coords))
                loo_data.p_loo_i[idx_dict] = loo_orig.p_loo_i[idx_dict]

    final_ks = loo_data.pareto_k.stack(__obs__=obs_dims).transpose("__obs__").values
    if np.any(final_ks[bad_obs_indices] > k_threshold):
        warnings.warn(
            f"After Moment Matching, {np.sum(final_ks > k_threshold)} observations still have "
            f"Pareto k > {k_threshold:.2f}.",
            UserWarning,
            stacklevel=2,
        )

    if not split and np.any(kfs > k_threshold):
        warnings.warn(
            "The accuracy of self-normalized importance sampling may be bad. "
            "Setting the argument 'split' to 'True' will likely improve accuracy.",
            UserWarning,
            stacklevel=2,
        )

    elpd_raw = logsumexp(log_likelihood, dims=sample_dims, b=1 / n_samples).sum().values
    loo_data.p = elpd_raw - loo_data.elpd

    if not pointwise:
        loo_data.elpd_i = None
        loo_data.pareto_k = None
        if hasattr(loo_data, "p_loo_i"):
            loo_data.p_loo_i = None

    return loo_data
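

# The private helpers below implement the algorithm: ``_loo_moment_match_i`` runs the
# iterative loop for a single observation, trying the affine maps described in the
# ``loo_moment_match`` Notes section in order (``_shift`` for the translation T1,
# ``_shift_and_scale`` for T2, ``_shift_and_cov`` for T3) and keeping a candidate
# transformation only when it lowers the Pareto k estimate. ``_split_moment_match``
# applies the accumulated transformation to half of the draws and recombines the halves
# with multiple importance sampling, ``_update_quantities_i`` recomputes the importance
# weights and Pareto k values after each candidate transformation, and
# ``_update_loo_data_i`` writes an improved pointwise result back into the ELPDData
# object.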


def _split_moment_match(
    upars,
    cov,
    total_shift,
    total_scaling,
    total_mapping,
    i,
    reff,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
):
    r"""Split moment matching importance sampling for PSIS-LOO-CV.

    Applies affine transformations based on the total moment matching transformation to
    half of the posterior draws, leaving the other half unchanged. These approximations
    to the leave-one-out posterior are then combined using multiple importance sampling.

    Based on the implicit adaptive importance sampling algorithm of [1]_ and the
    PSIS-LOO-CV method of [2]_ and [3]_.

    Parameters
    ----------
    upars : DataArray
        A DataArray representing the posterior draws of the model parameters in the
        unconstrained space. Must contain the dimensions `chain` and `draw` and a final
        dimension representing the different unconstrained parameters.
    cov : bool
        Whether to match the full covariance matrix of the samples (True) or just the
        marginal variances (False). Using the full covariance is more computationally
        expensive.
    total_shift : ndarray
        Vector containing the total shift (translation) applied to the parameters.
        Shape should match the parameter dimension of ``upars``.
    total_scaling : ndarray
        Vector containing the total scaling factors for the marginal variances.
        Shape should match the parameter dimension of ``upars``.
    total_mapping : ndarray
        Square matrix representing the linear transformation applied to the covariance
        matrix. Shape should be (d, d) where d is the parameter dimension.
    i : int
        Index of the specific observation to be left out for computing leave-one-out
        likelihood.
    reff : float
        Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided
        by the number of actual samples.
    log_prob_upars_fn : Callable[[DataArray], DataArray]
        A function that computes the log probability density of the *full posterior*
        distribution evaluated at given unconstrained parameter values (as a DataArray).
        Input and output must have dimensions `chain` and `draw`.
    log_lik_i_upars_fn : Callable[[DataArray, int], DataArray]
        A function that computes the log-likelihood of the *left-out observation* `i`
        evaluated at given unconstrained parameter values (as a DataArray). Input and
        output must have dimensions `chain` and `draw`.

    Returns
    -------
    SplitMomentMatch
        A namedtuple containing:

        - lwi: Updated log importance weights for each sample
        - lwfi: Updated log importance weights for full distribution
        - log_liki: Updated log likelihood values for the specific observation
        - reff: Relative MCMC efficiency (updated based on the split samples)

    References
    ----------
    .. [1] Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. (2021). *Implicitly
        Adaptive Importance Sampling*. Statistics and Computing. 31(2) (2021)
        https://doi.org/10.1007/s11222-020-09982-2
        arXiv preprint https://arxiv.org/abs/1906.08850.

    .. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing. 27(5) (2017)
        https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*. Journal of Machine
        Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
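    # Sketch of the steps below: build the total forward transformation and its inverse
    # from ``total_shift``/``total_scaling``/``total_mapping``, apply the forward map to
    # the first half of the draws and the inverse map to the second half, evaluate the
    # full-posterior log density on both modified draw sets and the left-out log
    # likelihood on the forward-transformed set, apply the Jacobian correction to the
    # inverse-transformed density, combine the two proposals with multiple importance
    # sampling weights, PSIS-smooth the weights, and update the relative efficiency from
    # the per-half ESS when there is more than one chain.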
    if not isinstance(upars, xr.DataArray):
        raise TypeError("upars must be a DataArray.")

    sample_dims = ["chain", "draw"]
    param_dim_list = [dim for dim in upars.dims if dim not in sample_dims]
    if len(param_dim_list) != 1:
        raise ValueError("upars must have exactly one dimension besides chain and draw.")
    param_dim = param_dim_list[0]

    if not all(dim in upars.dims for dim in sample_dims):
        raise ValueError(
            f"Required sample dimensions {sample_dims} not found in upars dimensions {upars.dims}"
        )

    dim = upars.sizes[param_dim]
    n_samples = upars.sizes["chain"] * upars.sizes["draw"]
    n_samples_half = n_samples // 2

    upars_stacked = upars.stack(__sample__=sample_dims).transpose("__sample__", param_dim)
    mean_original = upars_stacked.mean(dim="__sample__")

    if total_shift is None or total_shift.size == 0:
        total_shift = np.zeros(dim)
    if total_scaling is None or total_scaling.size == 0:
        total_scaling = np.ones(dim)
    if total_mapping is None or total_mapping.size == 0:
        total_mapping = np.eye(dim)

    # Forward transformation
    upars_trans = upars_stacked - mean_original
    upars_trans = upars_trans * xr.DataArray(total_scaling, dims=param_dim)
    if cov and dim > 0:
        upars_trans = xr.DataArray(
            upars_trans.data @ total_mapping.T,
            coords=upars_trans.coords,
            dims=upars_trans.dims,
        )
    upars_trans = upars_trans + (xr.DataArray(total_shift, dims=param_dim) + mean_original)

    # Inverse transformation
    upars_trans_inv = upars_stacked - (xr.DataArray(total_shift, dims=param_dim) + mean_original)
    if cov and dim > 0:
        try:
            inv_mapping_t = np.linalg.inv(total_mapping.T)
            upars_trans_inv = xr.DataArray(
                upars_trans_inv.data @ inv_mapping_t,
                coords=upars_trans_inv.coords,
                dims=upars_trans_inv.dims,
            )
        except np.linalg.LinAlgError:
            warnings.warn("Could not invert mapping matrix. Using identity.", UserWarning)
    upars_trans_inv = upars_trans_inv / xr.DataArray(total_scaling, dims=param_dim)
    upars_trans_inv = upars_trans_inv + (mean_original - xr.DataArray(total_shift, dims=param_dim))

    upars_trans_half = upars_stacked.copy(deep=True).unstack("__sample__")
    upars_trans_half = upars_trans_half.transpose(*sample_dims, param_dim)
    upars_trans_half.values.reshape(-1, dim)[:n_samples_half] = upars_trans.values.reshape(-1, dim)[
        :n_samples_half
    ]

    upars_trans_half_inv = upars_stacked.copy(deep=True).unstack("__sample__")
    upars_trans_half_inv = upars_trans_half_inv.transpose(*sample_dims, param_dim)
    upars_trans_half_inv.values.reshape(-1, dim)[n_samples_half:] = upars_trans_inv.values.reshape(
        -1, dim
    )[n_samples_half:]

    try:
        log_prob_half_trans = log_prob_upars_fn(upars_trans_half)
        log_prob_half_trans_inv = log_prob_upars_fn(upars_trans_half_inv)
    except Exception as e:
        raise ValueError(
            f"Could not compute log probabilities for transformed parameters: {e}"
        ) from e

    try:
        log_liki_half = log_lik_i_upars_fn(upars_trans_half, i)
        if not all(dim in log_liki_half.dims for dim in sample_dims) or len(
            log_liki_half.dims
        ) != len(sample_dims):
            raise ValueError(
                f"log_lik_i_upars_fn must return a DataArray with dimensions {sample_dims}"
            )
        if (
            log_liki_half.sizes["chain"] != upars.sizes["chain"]
            or log_liki_half.sizes["draw"] != upars.sizes["draw"]
        ):
            raise ValueError(
                "log_lik_i_upars_fn output shape does not match input sample dimensions"
            )
    except Exception as e:
        raise ValueError(f"Could not compute log likelihood for observation {i}: {e}") from e
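
    # The affine map changes volume, so the (negative) log absolute determinant of the
    # transformation, i.e. the sum of the log scalings plus the log determinant of the
    # mapping matrix, is added to the log density of the inverse-transformed half before
    # the two densities are combined in the importance weights below.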

    # Jacobian adjustment
    log_jacobian_det = 0.0
    if dim > 0:
        log_jacobian_det = -np.sum(np.log(np.abs(total_scaling)))
        try:
            det_val = np.linalg.det(total_mapping)
            if det_val > 0:
                log_jacobian_det -= np.log(det_val)
            else:
                log_jacobian_det -= np.inf
        except np.linalg.LinAlgError:
            log_jacobian_det -= np.inf

    log_prob_half_trans_inv_adj = log_prob_half_trans_inv + log_jacobian_det

    # Multiple importance sampling
    use_forward_log_prob = log_prob_half_trans > log_prob_half_trans_inv_adj
    raw_log_weights_half = -log_liki_half + log_prob_half_trans
    log_sum_terms = xr.where(
        use_forward_log_prob,
        log_prob_half_trans
        + xr.ufuncs.log1p(np.exp(log_prob_half_trans_inv_adj - log_prob_half_trans)),
        log_prob_half_trans_inv_adj
        + xr.ufuncs.log1p(np.exp(log_prob_half_trans - log_prob_half_trans_inv_adj)),
    )
    raw_log_weights_half -= log_sum_terms
    raw_log_weights_half = xr.where(np.isnan(raw_log_weights_half), -np.inf, raw_log_weights_half)
    raw_log_weights_half = xr.where(
        np.isposinf(raw_log_weights_half), -np.inf, raw_log_weights_half
    )

    # PSIS smoothing for half posterior
    lwi_psis_da, _ = raw_log_weights_half.azstats.psislw(r_eff=reff, dim=sample_dims)

    lr_full = lwi_psis_da + log_liki_half
    lr_full = xr.where(np.isnan(lr_full) | (np.isinf(lr_full) & (lr_full > 0)), -np.inf, lr_full)

    # PSIS smoothing for full posterior
    lwfi_psis_da, _ = lr_full.azstats.psislw(r_eff=reff, dim=sample_dims)

    n_chains = upars.sizes["chain"]
    if n_chains == 1:
        reff_updated = reff
    else:
        # Calculate ESS for each half of the data
        log_liki_half_1 = log_liki_half.isel(
            chain=slice(None), draw=slice(0, n_samples_half // n_chains)
        )
        log_liki_half_2 = log_liki_half.isel(
            chain=slice(None), draw=slice(n_samples_half // n_chains, None)
        )

        liki_half_1 = np.exp(log_liki_half_1)
        liki_half_2 = np.exp(log_liki_half_2)

        ess_1 = liki_half_1.azstats.ess(method="mean")
        ess_2 = liki_half_2.azstats.ess(method="mean")

        ess_1_value = float(ess_1.values) if hasattr(ess_1, "values") else float(ess_1)
        ess_2_value = float(ess_2.values) if hasattr(ess_2, "values") else float(ess_2)

        n_samples_1 = log_liki_half_1.size
        n_samples_2 = log_liki_half_2.size

        r_eff_1 = ess_1_value / n_samples_1
        r_eff_2 = ess_2_value / n_samples_2
        reff_updated = min(r_eff_1, r_eff_2)

    return SplitMomentMatch(
        lwi=lwi_psis_da,
        lwfi=lwfi_psis_da,
        log_liki=log_liki_half,
        reff=reff_updated,
    )


def _loo_moment_match_i(
    i,
    upars,
    log_likelihood,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
    max_iters,
    k_threshold,
    split,
    cov,
    orig_log_prob,
    ks,
    sample_dims,
    obs_dims,
    n_samples,
    n_params,
    param_dim_name,
):
    """Compute moment matching for a single observation."""
    n_chains = upars.sizes["chain"]
    n_draws = upars.sizes["draw"]

    log_liki = _get_log_likelihood_i(log_likelihood, i, obs_dims)
    liki = np.exp(log_liki)
    liki_reshaped = liki.values.reshape(n_chains, n_draws).T
    ess_val = ess(liki_reshaped, method="mean").item()
    reff_i = ess_val / n_samples if n_samples > 0 else 1.0

    log_ratio_i_init = -log_liki
    lwi, ki_tuple = log_ratio_i_init.azstats.psislw(r_eff=reff_i, dim=sample_dims)
    ki = ki_tuple[0].item() if isinstance(ki_tuple, tuple) else ki_tuple.item()
    original_ki = ks[i]

    upars_i = upars.copy(deep=True)
    total_shift = np.zeros(upars_i.sizes[param_dim_name])
    total_scaling = np.ones(upars_i.sizes[param_dim_name])
    total_mapping = np.eye(upars_i.sizes[param_dim_name])

    iterind = 1
    transformations_applied = False
    kfs_i = 0

    while iterind <= max_iters and ki > k_threshold:
        if iterind == max_iters:
            warnings.warn(
                f"Maximum number of moment matching iterations ({max_iters}) reached "
                f"for observation {i}. Final Pareto k is {ki:.2f}.",
                UserWarning,
                stacklevel=2,
            )
            break

        # Try Mean Shift
        try:
            shift_res = _shift(upars_i, lwi)
            quantities_i = _update_quantities_i(
                shift_res.upars,
                i,
                orig_log_prob,
                log_prob_upars_fn,
                log_lik_i_upars_fn,
                reff_i,
                sample_dims,
            )
            if quantities_i.ki < ki:
                ki = quantities_i.ki
                lwi = quantities_i.lwi
                log_liki = quantities_i.log_liki
                kfs_i = quantities_i.kfi
                upars_i = shift_res.upars
                total_shift = total_shift + shift_res.shift
                transformations_applied = True
                iterind += 1
                continue  # Restart, try mean shift again
        except RuntimeError as e:
            warnings.warn(
                f"Error during mean shift calculation for observation {i}: {e}. "
                "Stopping moment matching for this observation.",
                UserWarning,
                stacklevel=2,
            )
            break

        # Try Scale Shift
        try:
            scale_res = _shift_and_scale(upars_i, lwi)
            quantities_i = _update_quantities_i(
                scale_res.upars,
                i,
                orig_log_prob,
                log_prob_upars_fn,
                log_lik_i_upars_fn,
                reff_i,
                sample_dims,
            )
            if quantities_i.ki < ki:
                ki = quantities_i.ki
                lwi = quantities_i.lwi
                log_liki = quantities_i.log_liki
                kfs_i = quantities_i.kfi
                upars_i = scale_res.upars
                total_shift = total_shift + scale_res.shift
                total_scaling = total_scaling * scale_res.scaling
                transformations_applied = True
                iterind += 1
                continue  # Restart, try mean shift again
        except RuntimeError as e:
            warnings.warn(
                f"Error during scale shift calculation for observation {i}: {e}. "
                "Stopping moment matching for this observation.",
                UserWarning,
                stacklevel=2,
            )
            break

        # Try Covariance Shift
        if cov and n_samples >= 10 * n_params:
            try:
                cov_res = _shift_and_cov(upars_i, lwi)
                quantities_i = _update_quantities_i(
                    cov_res.upars,
                    i,
                    orig_log_prob,
                    log_prob_upars_fn,
                    log_lik_i_upars_fn,
                    reff_i,
                    sample_dims,
                )
                if quantities_i.ki < ki:
                    ki = quantities_i.ki
                    lwi = quantities_i.lwi
                    log_liki = quantities_i.log_liki
                    kfs_i = quantities_i.kfi
                    upars_i = cov_res.upars
                    total_shift = total_shift + cov_res.shift
                    total_mapping = cov_res.mapping @ total_mapping
                    transformations_applied = True
                    iterind += 1
                    continue  # Restart, try mean shift again
            except RuntimeError as e:
                warnings.warn(
                    f"Error during covariance shift calculation for observation {i}: {e}. "
                    "Stopping moment matching for this observation.",
                    UserWarning,
                    stacklevel=2,
                )
                break

        break

    if split and transformations_applied:
        try:
            split_res = _split_moment_match(
                upars=upars,
                cov=cov,
                total_shift=total_shift,
                total_scaling=total_scaling,
                total_mapping=total_mapping,
                i=i,
                reff=reff_i,
                log_prob_upars_fn=log_prob_upars_fn,
                log_lik_i_upars_fn=log_lik_i_upars_fn,
            )

            final_log_liki = split_res.log_liki
            final_lwi = split_res.lwi

            _, ki_split_tuple = split_res.lwi.azstats.psislw(r_eff=split_res.reff, dim=sample_dims)
            ki_split = (
                ki_split_tuple[0].item()
                if isinstance(ki_split_tuple, tuple)
                else ki_split_tuple.item()
            )
            final_ki = ki_split

            _, kf_tuple = split_res.lwfi.azstats.psislw(r_eff=split_res.reff, dim=sample_dims)
            kfs_i = kf_tuple[0].item() if isinstance(kf_tuple, tuple) else kf_tuple.item()

            reff_i = split_res.reff

            if ki_split > ki and ki <= k_threshold:
                warnings.warn(
                    f"Split transformation increased Pareto k for observation {i} "
                    f"({ki:.2f} -> {ki_split:.2f}). This may indicate numerical issues.",
                    UserWarning,
                    stacklevel=2,
                )
        except RuntimeError as e:
            warnings.warn(
                f"Error during split moment matching for observation {i}: {e}. "
                "Using non-split transformation result.",
                UserWarning,
                stacklevel=2,
            )
            final_log_liki = log_liki
            final_lwi = lwi
            final_ki = ki
    else:
        final_log_liki = log_liki
        final_lwi = lwi
        final_ki = ki

    if transformations_applied:
        liki_final = np.exp(final_log_liki)
        liki_final_reshaped = liki_final.values.reshape(n_chains, n_draws).T
        ess_val_final = ess(liki_final_reshaped, method="mean").item()
        reff_i = ess_val_final / n_samples if n_samples > 0 else 1.0

    return LooMomentMatchResult(
        final_log_liki=final_log_liki,
        final_lwi=final_lwi,
        final_ki=final_ki,
        kfs_i=kfs_i,
        reff_i=reff_i,
        original_ki=original_ki,
        i=i,
    )


def _update_loo_data_i(
    loo_data,
    i,
    new_elpd_i,
    new_pareto_k,
    log_liki,
    sample_dims,
    obs_dims,
    n_samples,
    original_log_liki=None,
    suppress_warnings=False,
):
    """Update the ELPDData object for a single observation."""
    if loo_data.elpd_i is None or loo_data.pareto_k is None:
        raise ValueError("loo_data must contain pointwise elpd_i and pareto_k values.")

    lpd_i_log_lik = original_log_liki if original_log_liki is not None else log_liki
    lpd_i = logsumexp(lpd_i_log_lik, dims=sample_dims, b=1 / n_samples).item()
    p_loo_i = lpd_i - new_elpd_i

    if len(obs_dims) == 1:
        idx_dict = {obs_dims[0]: i}
    else:
        coords = np.unravel_index(i, tuple(loo_data.elpd_i.sizes[d] for d in obs_dims))
        idx_dict = dict(zip(obs_dims, coords))

    loo_data.elpd_i[idx_dict] = new_elpd_i
    loo_data.pareto_k[idx_dict] = new_pareto_k

    if not hasattr(loo_data, "p_loo_i") or loo_data.p_loo_i is None:
        loo_data.p_loo_i = xr.full_like(loo_data.elpd_i, np.nan)
    loo_data.p_loo_i[idx_dict] = p_loo_i

    loo_data.elpd = np.nansum(loo_data.elpd_i.values)
    loo_data.se = np.sqrt(loo_data.n_data_points * np.nanvar(loo_data.elpd_i.values))

    loo_data.warning, loo_data.good_k = _warn_pareto_k(
        loo_data.pareto_k.values[~np.isnan(loo_data.pareto_k.values)],
        loo_data.n_samples,
        suppress=suppress_warnings,
    )


def _update_quantities_i(
    upars,
    i,
    orig_log_prob,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
    reff_i,
    sample_dims,
):
    """Update the moment matching quantities for a single observation."""
    log_prob_new = log_prob_upars_fn(upars)
    log_liki_new = log_lik_i_upars_fn(upars, i)

    log_ratio_i = -log_liki_new + log_prob_new - orig_log_prob
    lwi_new, ki_new_tuple = log_ratio_i.azstats.psislw(r_eff=reff_i, dim=sample_dims)
    ki_new = ki_new_tuple[0].item() if isinstance(ki_new_tuple, tuple) else ki_new_tuple.item()

    log_ratio_full = log_prob_new - orig_log_prob
    lwfi_new, kfi_new_tuple = log_ratio_full.azstats.psislw(r_eff=reff_i, dim=sample_dims)
    kfi_new = kfi_new_tuple[0].item() if isinstance(kfi_new_tuple, tuple) else kfi_new_tuple.item()

    return UpdateQuantities(
        lwi=lwi_new,
        lwfi=lwfi_new,
        ki=ki_new,
        kfi=kfi_new,
        log_liki=log_liki_new,
    )