arviz_stats.kl_divergence
- arviz_stats.kl_divergence(data1, data2, group='posterior', var_names=None, sample_dims=None, num_samples=500, round_to='2g', random_seed=212480)
Compute the Kullback-Leibler (KL) divergence.
The KL-divergence is a measure of how different two probability distributions are. It quantifies how much extra uncertainty we introduce when we use one distribution to approximate another. The KL-divergence is not symmetric, so swapping the data1 and data2 arguments will change the result.
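The asymmetry is easy to see with distributions that have a closed-form KL-divergence. The helper below (a standalone illustration, not part of arviz_stats) evaluates the analytic KL-divergence between two univariate normal distributions in both directions:

```python
import numpy as np

def kl_normal(mu1, s1, mu2, s2):
    # Closed-form KL( N(mu1, s1^2) || N(mu2, s2^2) )
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5

kl_normal(0, 1, 0, 2)  # ≈ 0.3181
kl_normal(0, 2, 0, 1)  # ≈ 0.8069, a different value: KL is not symmetric
```

Approximating a wide distribution with a narrow one incurs a larger divergence than the reverse, which is why the argument order matters.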
For details of the approximation used to compute the KL-divergence, see [1].
- Parameters:
- data1, data2 : xarray.DataArray, xarray.Dataset, xarray.DataTree, or InferenceData
- group : hashable, default "posterior"
  Group on which to compute the KL-divergence.
- var_names : str or list of str, optional
  Names of the variables for which the KL-divergence should be computed.
- sample_dims : iterable of hashable, optional
  Dimensions to be considered sample dimensions and to be reduced. Defaults to rcParams["data.sample_dims"].
- num_samples : int
  Number of samples to use for the distance calculation. Defaults to 500.
- round_to : int or str, optional
  If an integer, the number of decimal places to round the result to. If a string of the form '2g', the number of significant digits to round the result to. Defaults to '2g'. Use None to return raw numbers.
- random_seed : int
  Random seed for reproducibility. Use None for no seed.
- Returns:
- KL-divergence : float
References
[1] F. Perez-Cruz, "Kullback-Leibler divergence estimation of continuous distributions", IEEE International Symposium on Information Theory (2008). https://doi.org/10.1109/ISIT.2008.4595271. Preprint: https://www.tsc.uc3m.es/~fernando/bare_conf3.pdf
Examples
Calculate the KL-divergence between the posterior distributions of the variable mu in the centered and non-centered eight schools models:
In [1]: from arviz_stats import kl_divergence
   ...: from arviz_base import load_arviz_data
   ...: data1 = load_arviz_data('centered_eight')
   ...: data2 = load_arviz_data('non_centered_eight')
   ...: kl_divergence(data1, data2, var_names="mu")
Out[1]: 1.1
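For intuition about the sample-based approximation of [1], the nearest-neighbour estimator can be sketched in a few lines of NumPy. This is a minimal one-dimensional sketch, not the arviz_stats implementation: it estimates KL(P||Q) from samples x ~ P and y ~ Q using k-th nearest-neighbour distances within x and from x to y.

```python
import numpy as np

def knn_kl_divergence(x, y, k=1):
    """Nearest-neighbour estimate of KL(P||Q) from 1-D samples x ~ P, y ~ Q.

    Sketch of the estimator analysed in Perez-Cruz (2008); the log-ratio of
    k-NN distances estimates the log density ratio at each sample point.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n, m = len(x), len(y)
    # rho_i: distance from x_i to its k-th nearest neighbour in x (excluding itself)
    rho = np.array(
        [np.sort(np.abs(np.delete(x, i) - x[i]))[k - 1] for i in range(n)]
    )
    # nu_i: distance from x_i to its k-th nearest neighbour in y
    nu = np.array([np.sort(np.abs(y - xi))[k - 1] for xi in x])
    d = 1  # dimension of the samples
    return d * np.mean(np.log(nu / rho)) + np.log(m / (n - 1))

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=2000)  # samples from P = N(0, 1)
y = rng.normal(1.0, 1.0, size=2000)  # samples from Q = N(1, 1)
est = knn_kl_divergence(x, y)        # analytic KL(P||Q) is 0.5
```

With a few thousand samples the estimate lands close to the analytic value of 0.5; like the function documented here, the estimator is asymmetric in its two sample sets.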