arviz_stats.kl_divergence

arviz_stats.kl_divergence(data1, data2, group='posterior', var_names=None, sample_dims=None, num_samples=500, round_to='2g', random_seed=212480)

Compute the Kullback-Leibler (KL) divergence.

The KL-divergence is a measure of how different two probability distributions are. It represents how much extra uncertainty we introduce when we use one distribution to approximate another. The KL-divergence is not symmetric, so changing the order of the data1 and data2 arguments will change the result.
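
For densities p and q, the KL-divergence of q from p is defined as

\[ \mathrm{KL}(p \parallel q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx \]

It is non-negative and zero only when p = q, and in general KL(p ∥ q) ≠ KL(q ∥ p).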

For details of the approximation used to compute the KL-divergence see [1].
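
The estimator in [1] works from nearest-neighbor distances between the two sets of draws. As a rough sketch only (not the arviz_stats implementation), a 1-nearest-neighbor version for draws x ~ p and y ~ q stacked into (n, d) and (m, d) arrays might look like the following; knn_kl_estimate is a hypothetical name:

import numpy as np
from scipy.spatial import cKDTree

def knn_kl_estimate(x, y):
    """1-NN estimate of KL(p || q) from draws x ~ p (n, d) and y ~ q (m, d)."""
    n, d = x.shape
    m = y.shape[0]
    # Distance from each x_i to its nearest neighbor within x, excluding itself
    # (k=2 returns the point itself first, then the true nearest neighbor).
    r = cKDTree(x).query(x, k=2)[0][:, 1]
    # Distance from each x_i to its nearest neighbor in y.
    s = cKDTree(y).query(x, k=1)[0]
    # Nearest-neighbor estimator; duplicate draws (zero distances) would need
    # extra handling before taking the log.
    return d / n * np.sum(np.log(s / r)) + np.log(m / (n - 1))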

Parameters:
data1, data2 : xarray.DataArray, xarray.Dataset, xarray.DataTree, or InferenceData

Input data containing the samples from the two distributions to compare.

group : hashable, default “posterior”

Group on which to compute the KL-divergence.

var_names : str or list of str, optional

Names of the variables for which the KL-divergence should be computed.

sample_dims : iterable of hashable, optional

Dimensions to be considered sample dimensions; these are reduced over when computing the KL-divergence. Defaults to rcParams["data.sample_dims"].

num_samples : int

Number of samples to use for the distance calculation. Default is 500.

round_to : int or str, optional

If an integer, the number of decimal places to round the result to. If a string of the form ‘2g’, the number of significant digits to round the result to. Defaults to ‘2g’. Use None to return raw numbers.

random_seed : int

Random seed for reproducibility. Use None for no seed.

Returns:
KL-divergence : float

References

[1]

F. Perez-Cruz. “Kullback-Leibler divergence estimation of continuous distributions”. IEEE International Symposium on Information Theory (2008). https://doi.org/10.1109/ISIT.2008.4595271. Preprint: https://www.tsc.uc3m.es/~fernando/bare_conf3.pdf

Examples

Calculate the KL-divergence between the posterior distributions for the variable mu in the centered and non-centered eight schools models

In [1]: from arviz_stats import kl_divergence
   ...: from arviz_base import load_arviz_data
   ...: data1 = load_arviz_data('centered_eight')
   ...: data2 = load_arviz_data('non_centered_eight')
   ...: kl_divergence(data1, data2, var_names="mu")
   ...: 
Out[1]: 1.1
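
Because the KL-divergence is not symmetric, swapping data1 and data2 computes the divergence in the opposite direction and will in general return a different value:

In [2]: kl_divergence(data2, data1, var_names="mu")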