arviz_stats.wasserstein


arviz_stats.wasserstein(data1, data2, group='posterior', var_names=None, sample_dims=None, joint=True, num_samples=500, round_to='2g', random_seed=212480)

Compute the Wasserstein-1 distance.

The Wasserstein distance, also called the Earth mover’s distance or the optimal transport distance, is a metric that quantifies how far apart two probability distributions are [1].
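As a rough illustration of what the Wasserstein-1 distance measures in one dimension, the sketch below compares two equal-sized samples. It relies on the fact that, for equal-sized one-dimensional samples, the optimal transport plan pairs the sorted values, so the distance is the mean absolute difference between them. The sample sizes, distributions, and variable names are assumptions chosen for the example and are not part of arviz_stats.

```python
import numpy as np
from scipy.stats import wasserstein_distance

# Two illustrative 1-D samples of equal size (assumed data, not from ArviZ)
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=1000)
y = rng.normal(0.5, 1.0, size=1000)

# For equal-sized samples, optimal transport in 1-D pairs the sorted values
w1_sorted = np.mean(np.abs(np.sort(x) - np.sort(y)))

# SciPy computes the same quantity from the empirical distributions
w1_scipy = wasserstein_distance(x, y)

print(w1_sorted, w1_scipy)
```

The two printed values should agree up to floating-point error.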

Parameters:
data1, data2: xarray.DataArray, xarray.Dataset, xarray.DataTree, or InferenceData

group: hashable, default “posterior”

Group on which to compute the Wasserstein distance.

var_names: str or list of str, optional

Names of the variables for which the Wasserstein distance should be computed.

sample_dims: iterable of hashable, optional

Dimensions to be treated as sample dimensions and reduced. Defaults to rcParams["data.sample_dims"].

joint: bool, default True

Whether to compute the Wasserstein distance for the joint distribution (True) or over the marginals (False).

num_samples: int

Number of samples to use for the distance calculation. Default is 500.

round_to: int or str, optional

If an integer, the number of decimal places to round the result to. If a string of the form ‘2g’, the number of significant digits to round the result to. Defaults to ‘2g’. Use None to return raw numbers.

random_seed: int

Random seed for reproducibility. Use None for no seed.
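The following sketch illustrates how num_samples, random_seed, and round_to could interact, based only on the descriptions above. The helpers subsample and round_result are hypothetical and written for this example; they are not part of arviz_stats, and the library's internal implementation may differ.

```python
import numpy as np
from scipy.stats import wasserstein_distance


def subsample(draws, num_samples=500, random_seed=212480):
    """Hypothetical helper: keep at most num_samples draws, chosen reproducibly."""
    rng = np.random.default_rng(random_seed)
    n = min(num_samples, draws.shape[0])
    idx = rng.choice(draws.shape[0], size=n, replace=False)
    return draws[idx]


def round_result(value, round_to="2g"):
    """Hypothetical helper following the documented round_to semantics."""
    if round_to is None:
        return value  # raw number
    if isinstance(round_to, int):
        return round(value, round_to)  # decimal places
    # A string such as "2g" means two significant digits
    return float(f"{value:.{round_to}}")


# Illustrative draws (assumed data, not from ArviZ)
rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=4000)
b = rng.normal(0.5, 1.0, size=4000)

raw = wasserstein_distance(subsample(a), subsample(b))
print(round_result(raw), round_result(raw, 3), round_result(raw, None))
```

With a fixed seed the same draws are selected on every call, so repeated runs give the same distance.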

Returns:
wasserstein_distance: float

Notes

The computation is faster for the marginals (joint=False). Using the marginals is equivalent to assuming that the variables are independent, which is usually not the case. This function uses scipy.stats.wasserstein_distance to compute the distance over the marginals and scipy.stats.wasserstein_distance_nd for the joint distribution.

References

Examples

Calculate the Wasserstein distance between the posterior distributions of the variable mu in the centered and non-centered eight schools models:

In [1]: from arviz_stats import wasserstein
   ...: from arviz_base import load_arviz_data
   ...: data1 = load_arviz_data('centered_eight')
   ...: data2 = load_arviz_data('non_centered_eight')
   ...: wasserstein(data1, data2, var_names="mu")
   ...: 
Out[1]: 0.29