Source code for arviz_stats.psense

"""Power-scaling sensitivity diagnostics."""

import logging

import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree, dataset_to_dataframe, extract
from arviz_base.labels import BaseLabeller

from arviz_stats.utils import get_log_likelihood_dataset, get_log_prior
from arviz_stats.validate import validate_dims

_log = logging.getLogger(__name__)

labeller = BaseLabeller()

__all__ = ["psense", "psense_summary"]


def psense(
    data,
    var_names=None,
    filter_vars=None,
    group="prior",
    coords=None,
    sample_dims=None,
    alphas=(0.99, 1.01),
    group_var_names=None,
    group_coords=None,
):
    """
    Compute power-scaling sensitivity values.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood and/or
        log_prior groups.
    var_names : list of str, optional
        Names of the posterior variables to include in the power-scaling sensitivity
        diagnostic.
    filter_vars : {None, "like", "regex"}, default None
        Used for `var_names` only.
        If ``None`` (default), interpret `var_names` as the real variable names.
        If "like", interpret `var_names` as substrings of the real variable names.
        If "regex", interpret `var_names` as regular expressions on the real variable names.
    group : {"prior", "likelihood"}, default "prior"
        If "likelihood", the pointwise log likelihood values are retrieved from the
        ``log_likelihood`` group and added together.
        If "prior", the log prior values are retrieved from the ``log_prior`` group.
    coords : dict, optional
        Coordinates defining a subset of the posterior. Only this subset is used
        when computing the sensitivity diagnostic.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce. Defaults to ``rcParams["data.sample_dims"]``.
    alphas : tuple
        Lower and upper alpha values for the gradient calculation. Defaults to (0.99, 1.01).
    group_var_names : str, optional
        Name of the log-prior or log-likelihood variables to use.
    group_coords : dict, optional
        Coordinates defining a subset of the group element for which to compute the
        sensitivity diagnostic.

    Returns
    -------
    xarray.DataTree
        DataTree of power-scaling sensitivity diagnostic values. Higher values indicate
        greater sensitivity. Prior sensitivity above 0.05 indicates an informative prior.
        Likelihood sensitivity below 0.05 indicates a weak or non-informative likelihood.

    Notes
    -----
    The diagnostic is computed by power-scaling either the prior or the likelihood
    and measuring the degree to which the posterior changes, as described in [1]_.
    It uses Pareto-smoothed importance sampling to avoid refitting the model.

    References
    ----------
    .. [1] Kallioinen et al. *Detecting and diagnosing prior and likelihood sensitivity with
       power-scaling*. Stat Comput 34, 57 (2024). https://doi.org/10.1007/s11222-023-10366-5
    """
    data = convert_to_datatree(data)

    dataset = extract(
        data,
        var_names=var_names,
        filter_vars=filter_vars,
        group="posterior",
        combined=False,
        keep_dataset=True,
    )
    if coords is not None:
        dataset = dataset.sel(coords)

    lower_w, upper_w = _get_power_scale_weights(
        data,
        alphas=alphas,
        group=group,
        sample_dims=sample_dims,
        group_var_names=group_var_names,
        group_coords=group_coords,
    )

    return dataset.azstats.power_scale_sense(
        lower_w=lower_w,
        upper_w=upper_w,
        lower_alpha=alphas[0],
        upper_alpha=alphas[1],
        sample_dims=sample_dims,
    )
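
A minimal usage sketch, assuming the "rugby" example dataset from arviz_base
(the same dataset used in the psense_summary docstring below), which contains
the required log_likelihood and log_prior groups:

    from arviz_base import load_arviz_data
    from arviz_stats import psense

    rugby = load_arviz_data("rugby")
    # Sensitivity of the posterior to power-scaling the prior; values above
    # 0.05 suggest an informative prior.
    psense(rugby, var_names="atts", group="prior")
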
def psense_summary(
    data,
    var_names=None,
    filter_vars=None,
    coords=None,
    sample_dims=None,
    threshold=0.05,
    alphas=(0.99, 1.01),
    prior_var_names=None,
    likelihood_var_names=None,
    prior_coords=None,
    likelihood_coords=None,
    round_to=3,
):
    """
    Compute the prior/likelihood sensitivity based on power-scaling perturbations.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood and/or
        log_prior groups.
    var_names : list of str, optional
        Names of the posterior variables to include in the power-scaling sensitivity
        diagnostic.
    filter_vars : {None, "like", "regex"}, default None
        Used for `var_names` only.
        If ``None`` (default), interpret `var_names` as the real variable names.
        If "like", interpret `var_names` as substrings of the real variable names.
        If "regex", interpret `var_names` as regular expressions on the real variable names.
    coords : dict, optional
        Coordinates defining a subset of the posterior. Only this subset is used
        when computing the sensitivity diagnostic.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce. Defaults to ``rcParams["data.sample_dims"]``.
    threshold : float, optional
        Threshold value used for the sensitivity diagnosis. Default is 0.05.
    alphas : tuple
        Lower and upper alpha values for the gradient calculation. Defaults to (0.99, 1.01).
    prior_var_names : str, optional
        Name of the log-prior variables to include in the power-scaling sensitivity
        diagnostic.
    likelihood_var_names : str, optional
        Name of the log-likelihood variables to include in the power-scaling sensitivity
        diagnostic.
    prior_coords : dict, optional
        Coordinates defining a subset of the group element for which to compute the
        log-prior sensitivity diagnostic.
    likelihood_coords : dict, optional
        Coordinates defining a subset of the group element for which to compute the
        log-likelihood sensitivity diagnostic.
    round_to : int, optional
        Number of decimal places to round the sensitivity values to. Default is 3.

    Returns
    -------
    psense_df : DataFrame
        DataFrame containing the prior and likelihood sensitivity values for each
        variable in the data, plus a diagnosis column with the following values:

        - "potential prior-data conflict" if both the prior and likelihood
          sensitivities are above the threshold
        - "potential strong prior / weak likelihood" if the prior sensitivity is
          above the threshold and the likelihood sensitivity is below it
        - "✓" otherwise

    Examples
    --------
    .. ipython::

        In [1]: from arviz_base import load_arviz_data
           ...: from arviz_stats import psense_summary
           ...: rugby = load_arviz_data("rugby")
           ...: psense_summary(rugby, var_names="atts")

    Notes
    -----
    The diagnostic is computed by power-scaling either the prior or the likelihood
    and measuring the degree to which the posterior changes, as described in [1]_.
    It uses Pareto-smoothed importance sampling to avoid refitting the model.

    References
    ----------
    .. [1] Kallioinen et al. *Detecting and diagnosing prior and likelihood sensitivity with
       power-scaling*. Stat Comput 34, 57 (2024). https://doi.org/10.1007/s11222-023-10366-5
    """
    pssdp = psense(
        data,
        var_names=var_names,
        filter_vars=filter_vars,
        group="prior",
        coords=coords,
        sample_dims=sample_dims,
        alphas=alphas,
        group_var_names=prior_var_names,
        group_coords=prior_coords,
    )
    pssdl = psense(
        data,
        var_names=var_names,
        filter_vars=filter_vars,
        group="likelihood",
        coords=coords,
        sample_dims=sample_dims,
        alphas=alphas,
        group_var_names=likelihood_var_names,
        group_coords=likelihood_coords,
    )

    joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
        component=["prior", "likelihood"]
    )
    psense_df = dataset_to_dataframe(joined, sample_dims=["component"]).T

    def _diagnose(row):
        if row["prior"] >= threshold and row["likelihood"] >= threshold:
            return "potential prior-data conflict"
        if row["prior"] > threshold > row["likelihood"]:
            return "potential strong prior / weak likelihood"
        return "✓"

    psense_df["diagnosis"] = psense_df.apply(_diagnose, axis=1)

    if "potential" in "".join(psense_df["diagnosis"]):
        _log.warning(
            "We detected potential issues. For more information on how to interpret "
            "the results, please check\n"
            "https://arviz-devs.github.io/EABM/Chapters/"
            "Sensitivity_checks.html#interpreting-sensitivity-diagnostics-summary\n"
            "or read the original paper https://doi.org/10.1007/s11222-023-10366-5"
        )

    return psense_df.round(round_to)
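
A short follow-up sketch of how the summary is typically inspected; the filter
on the diagnosis column mirrors the _diagnose helper above:

    from arviz_base import load_arviz_data
    from arviz_stats import psense_summary

    rugby = load_arviz_data("rugby")
    summary = psense_summary(rugby, var_names="atts", threshold=0.05)
    # Keep only the rows that were flagged as potentially problematic.
    flagged = summary[summary["diagnosis"] != "✓"]
    print(flagged)
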
def power_scale_dataset(data, group, alphas, sample_dims, group_var_names, group_coords):
    """Resample the posterior based on power-scaled weights.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood and/or
        log_prior groups.
    group : str
        Group to power-scale. Either "prior" or "likelihood".
    alphas : tuple of float
        Lower and upper alpha values for power scaling.
    sample_dims : str or sequence of hashable
        Dimensions to reduce.
    group_var_names : str
        Name of the log-prior or log-likelihood variables to use.
    group_coords : dict
        Coordinates defining a subset of the group element for which to compute the
        sensitivity diagnostic.

    Returns
    -------
    Dataset
        Dataset with the resampled data.
    """
    dt = convert_to_datatree(data)

    lower_w, upper_w = _get_power_scale_weights(
        dt,
        alphas,
        group=group,
        sample_dims=sample_dims,
        group_var_names=group_var_names,
        group_coords=group_coords,
    )
    lower_w = lower_w.values.flatten()
    upper_w = upper_w.values.flatten()
    s_size = len(lower_w)

    idxs_to_drop = sample_dims if len(sample_dims) == 1 else ["sample"] + sample_dims
    idxs_to_drop = set(idxs_to_drop).union(
        [
            idx
            for idx in dt["posterior"].xindexes
            if any(dim in dt["posterior"][idx].dims for dim in sample_dims)
        ]
    )
    resampled = [
        extract(
            dt,
            group="posterior",
            sample_dims=sample_dims,
            num_samples=s_size,
            weights=weights,
            random_seed=42,
            resampling_method="stratified",
        ).drop_indexes(idxs_to_drop)
        for weights in (lower_w, upper_w)
    ]
    resampled.insert(
        1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop)
    )

    return xr.concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]])


def _get_power_scale_weights(
    dt, alphas=None, group=None, sample_dims=None, group_var_names=None, group_coords=None
):
    """Compute the power-scaling importance sampling weights."""
    sample_dims = validate_dims(sample_dims)

    if group == "likelihood":
        group_draws = get_log_likelihood_dataset(dt, var_names=group_var_names)
    elif group == "prior":
        group_draws = get_log_prior(dt, var_names=group_var_names)
    else:
        raise ValueError("Value for `group` argument not recognized")

    if group_coords is not None:
        group_draws = group_draws.sel(group_coords)

    # Stack the different variables (if any) and the dimensions within each variable
    # (if any) into a flat "latent-obs_var" dimension, then sum over it. After this,
    # group_draws is a DataArray whose only dimensions are sample_dims.
    group_draws = group_draws.to_stacked_array("latent-obs_var", sample_dims=sample_dims).sum(
        "latent-obs_var"
    )

    # Calculate importance sampling weights for the lower and upper alpha power-scalings.
    lower_w = np.exp(group_draws.azstats.power_scale_lw(alpha=alphas[0], dim=sample_dims))
    lower_w = lower_w / lower_w.sum(sample_dims)

    upper_w = np.exp(group_draws.azstats.power_scale_lw(alpha=alphas[1], dim=sample_dims))
    upper_w = upper_w / upper_w.sum(sample_dims)

    return lower_w, upper_w
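
A usage sketch for power_scale_dataset, assuming it is imported from this module
and that the posterior uses the usual ["chain", "draw"] sample dimensions; the
alpha values here are illustrative. The result stacks the lower-scaled, base,
and upper-scaled draws along an "alpha" dimension with coordinates
[alphas[0], 1, alphas[1]]:

    from arviz_base import load_arviz_data
    from arviz_stats.psense import power_scale_dataset

    rugby = load_arviz_data("rugby")
    scaled = power_scale_dataset(
        rugby,
        group="prior",
        alphas=(0.8, 1.25),
        sample_dims=["chain", "draw"],
        group_var_names=None,
        group_coords=None,
    )
    # Posterior draws resampled as if the prior had been raised to the power 0.8.
    scaled.sel(alpha=0.8)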