arviz_stats.rhat

arviz_stats.rhat(data, sample_dims=None, group='posterior', var_names=None, filter_vars=None, coords=None, method='rank', chain_axis=0, draw_axis=1)

Compute an estimate of the rank-normalized split R-hat for a set of traces.

The rank-normalized R-hat diagnostic tests for lack of convergence by comparing the variance between multiple chains to the variance within each chain. If convergence has been achieved, the between-chain and within-chain variances should be approximately equal, so R-hat is close to 1. To be most effective at detecting nonconvergence, each chain should be initialized at starting values that are dispersed relative to the target distribution.

Parameters:
data : array_like, xarray.DataArray, xarray.Dataset, xarray.DataTree, DataArrayGroupBy, DatasetGroupBy, or idata-like

Input data. It will have different pre-processing applied to it depending on its type:

  • array-like: call array layer within arviz-stats.

  • xarray object: apply dimension aware function to all relevant subsets

  • others: passed to arviz_base.convert_to_dataset

At least 2 posterior chains are needed to compute this diagnostic of one or more stochastic parameters.

sample_dims : iterable of hashable, optional

Dimensions to treat as sample dimensions; these are reduced by the computation. Defaults to rcParams["data.sample_dims"].

group : hashable, default “posterior”

Group on which to compute the R-hat.

var_names : str or list of str, optional

Names of the variables for which the Rhat should be computed.

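filter_vars : {None, “like”, “regex”}, default None

If None, interpret var_names as exact variable names. If “like”, interpret them as substrings of the variable names. If “regex”, interpret them as regular expressions matched against the variable names.

coords : dict, optional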

Dictionary of dimension/index names to coordinate values defining a subset of the data for which to perform the computation.

method : str, default “rank”

Valid methods are:

  • “rank” (recommended by Vehtari et al. (2021))

  • “split”

  • “folded”

  • “z_scale”

  • “identity”
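To illustrate why a folded variant is useful, the sketch below follows the idea in Vehtari et al. (2021): the diagnostic is applied to absolute deviations from the median, so chains that agree in location but differ in scale are still flagged. It is a plain-Python illustration, not the arviz-stats implementation; the helper names basic_rhat and fold are hypothetical, and splitting and rank normalization are left out for brevity.

```python
import math
import random
from statistics import fmean, median, variance


def basic_rhat(chains):
    """sqrt(V_hat / W) for equal-length chains, without splitting or ranking."""
    n = len(chains[0])
    within = fmean([variance(c) for c in chains])        # W
    between = n * variance([fmean(c) for c in chains])   # B
    var_hat = (n - 1) / n * within + between / n         # V_hat
    return math.sqrt(var_hat / within)


def fold(chains):
    """Absolute deviation of every draw from the median of all pooled draws."""
    center = median([x for c in chains for x in c])
    return [[abs(x - center) for x in c] for c in chains]


rng = random.Random(1)
# Same location, different scale: chain means agree, chain spreads do not.
scales = (1.0, 1.0, 1.0, 3.0)
chains = [[rng.gauss(0.0, s) for _ in range(1000)] for s in scales]

rhat_raw = basic_rhat(chains)           # stays close to 1
rhat_folded = basic_rhat(fold(chains))  # clearly above 1
print(f"raw: {rhat_raw:.3f}  folded: {rhat_folded:.3f}")
```

On the raw draws the chain means coincide, so the between-chain term is small; after folding, the wider chain has a larger mean absolute deviation and the disagreement becomes visible. Vehtari et al. (2021) recommend reporting the maximum of the rank-normalized split R-hat and its folded counterpart.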

chain_axis, draw_axis : int, optional

Integer indices of the axes that correspond to the chain and draw dimensions. chain_axis can be None.

Returns:
ndarray, xarray.DataArray, xarray.Dataset, xarray.DataTree

Requested Rhat summary of the provided input

See also

arviz.ess

Calculate estimate of the effective sample size (ess).

arviz.mcse

Calculate Markov Chain Standard Error statistic.

plot_forest

Forest plot to compare HDI intervals from a number of distributions.

Notes

The diagnostic is computed by:

\[\hat{R} = \sqrt{\frac{\hat{V}}{W}}\]

where \(W\) is the within-chain variance and \(\hat{V}\) is the posterior variance estimate for the pooled rank-traces. This is the potential scale reduction factor, which converges to unity when each of the traces is a sample from the target posterior. Values greater than one indicate that one or more chains have not yet converged.
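As a worked form of the ratio above, assume \(M\) chains of \(N\) draws each (after any splitting). The symbols \(M\), \(N\), \(B\), \(\bar{\theta}_{\cdot m}\), \(\bar{\theta}_{\cdot\cdot}\) and \(s_m^2\) are notation introduced here, following the usual definitions in Gelman et al. BDA3 (2013); the two variance terms can then be written as:

\[B = \frac{N}{M-1}\sum_{m=1}^{M}\left(\bar{\theta}_{\cdot m} - \bar{\theta}_{\cdot\cdot}\right)^2, \qquad W = \frac{1}{M}\sum_{m=1}^{M} s_m^2\]

\[\hat{V} = \frac{N-1}{N}\,W + \frac{1}{N}\,B\]

where \(\bar{\theta}_{\cdot m}\) and \(s_m^2\) are the mean and sample variance of chain \(m\), and \(\bar{\theta}_{\cdot\cdot}\) is the mean of the chain means. Dividing by \(W\) gives \(\hat{V}/W = (N-1)/N + B/(NW)\), so when the chain means agree, \(B/N\) is small relative to \(W\) and \(\hat{R}\) approaches 1.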

Rank values are calculated over all the chains with scipy.stats.rankdata. Each chain is split in two and normalized with the z-transform following Vehtari et al. (2021).
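The steps described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the arviz-stats code: all function names are hypothetical, ranks are averaged over ties, and the fractional offsets 3/8 and 1/4 are the Blom-type choice assumed here for the z-transform. Only the rank-normalized split variant is shown, without the folded counterpart.

```python
import math
import random
from statistics import NormalDist, fmean, variance


def split_chains(chains):
    """Split every chain into halves (the middle draw is dropped for odd lengths)."""
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.append(chain[:half])
        halves.append(chain[len(chain) - half:])
    return halves


def rank_normalize(chains):
    """Replace each draw by the normal score of its rank among all pooled draws."""
    pooled = sorted(x for chain in chains for x in chain)
    total = len(pooled)
    positions = {}
    for pos, value in enumerate(pooled, start=1):
        positions.setdefault(value, []).append(pos)
    avg_rank = {v: fmean(p) for v, p in positions.items()}  # ties share a rank
    inv_cdf = NormalDist().inv_cdf
    return [
        [inv_cdf((avg_rank[x] - 0.375) / (total + 0.25)) for x in chain]
        for chain in chains
    ]


def rhat_from_chains(chains):
    """sqrt(V_hat / W) from within- and between-chain variances."""
    n = len(chains[0])
    within = fmean([variance(c) for c in chains])
    between = n * variance([fmean(c) for c in chains])
    return math.sqrt(((n - 1) / n * within + between / n) / within)


def rank_normalized_split_rhat(chains):
    return rhat_from_chains(rank_normalize(split_chains(chains)))


rng = random.Random(0)
# Four chains drawn from the same distribution.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Three chains agree, one sits at a shifted location.
stuck = [[rng.gauss(mu, 1.0) for _ in range(500)] for mu in (0.0, 0.0, 0.0, 3.0)]

rhat_mixed = rank_normalized_split_rhat(mixed)
rhat_stuck = rank_normalized_split_rhat(stuck)
print(f"mixed: {rhat_mixed:.3f}  stuck: {rhat_stuck:.3f}")
```

Splitting lets the diagnostic notice drift within a single chain, and working on ranks keeps it well defined for heavy-tailed draws whose variance may not exist.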

References

  • Vehtari et al. (2021). Rank-normalization, folding, and localization: An improved Rhat for assessing convergence of MCMC. Bayesian Analysis, 16(2):667-718.

  • Gelman et al. BDA3 (2013)

  • Brooks and Gelman (1998)

  • Gelman and Rubin (1992)

Examples

Calculate the R-hat using the default arguments:

In [1]: from arviz_base import load_arviz_data
   ...: import arviz_stats as azs
   ...: data = load_arviz_data('non_centered_eight')
   ...: azs.rhat(data)
   ...: 
Out[1]: 
<xarray.DataTree 'posterior'>
Group: /posterior
    Dimensions:  (school: 8)
    Coordinates:
      * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
    Data variables:
        mu       float64 8B 1.003
        theta_t  (school) float64 64B 1.0 1.001 0.9997 1.001 ... 1.004 0.9992 1.002
        tau      float64 8B 1.003
        theta    (school) float64 64B 1.003 0.9992 1.003 1.001 ... 1.002 1.001 1.003

Calculate the R-hat of some variables using the folded method:

In [2]: azs.rhat(data, var_names=["mu", "theta_t"], method="folded")
Out[2]: 
<xarray.DataTree 'posterior'>
Group: /posterior
    Dimensions:  (school: 8)
    Coordinates:
      * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
    Data variables:
        mu       float64 8B 0.9997
        theta_t  (school) float64 64B 1.0 1.001 0.9997 1.001 ... 1.004 0.9992 1.002