arviz_stats.wasserstein
- arviz_stats.wasserstein(data1, data2, group='posterior', var_names=None, sample_dims=None, joint=True, num_samples=500, round_to='2g', random_seed=212480)
Compute the Wasserstein-1 distance.
The Wasserstein distance, also called the Earth mover’s distance or the optimal transport distance, is a metric that measures how far apart two probability distributions are [1].
- Parameters:
- data1, data2: xarray.DataArray, xarray.Dataset, xarray.DataTree, or InferenceData
- group: hashable, default "posterior"
Group on which to compute the Wasserstein distance.
- var_names: str or list of str, optional
Names of the variables for which the Wasserstein distance should be computed.
- sample_dims: iterable of hashable, optional
Dimensions to be considered sample dimensions; these are the dimensions that get reduced. Defaults to rcParams["data.sample_dims"].
- joint: bool, default True
Whether to compute the Wasserstein distance for the joint distribution (True) or over the marginals (False).
- num_samples: int
Number of samples to use for the distance calculation. Default is 500.
- round_to: int or str, optional
If an integer, the number of decimal places to round the result to. If a string of the form '2g', the number of significant digits to round the result to. Defaults to '2g'. Use None to return raw numbers.
- random_seed: int
Random seed for reproducibility. Use None for no seed.
- Returns:
- wasserstein_distance: float
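The two forms of round_to in the parameter list correspond to the usual distinction between decimal places and significant digits. The sketch below illustrates that distinction in plain Python. It assumes the string form behaves like Python's "g" format specifier; the helper round_value is hypothetical and is not part of arviz_stats.

```python
def round_value(value, round_to="2g"):
    """Hypothetical helper illustrating the documented round_to options."""
    if round_to is None:
        # None: return the raw number untouched.
        return value
    if isinstance(round_to, int):
        # Integer: number of decimal places.
        return round(value, round_to)
    if isinstance(round_to, str) and round_to.endswith("g"):
        # String such as "2g": number of significant digits
        # (assumed to match Python's "g" format specifier).
        digits = int(round_to[:-1])
        return float(f"{value:.{digits}g}")
    raise ValueError(f"Unsupported round_to value: {round_to!r}")


raw = 0.0123456
print(round_value(raw, 2))     # 0.01   (two decimal places)
print(round_value(raw, "2g"))  # 0.012  (two significant digits)
print(round_value(raw, None))  # 0.0123456 (raw number)
```

For small values the two conventions differ noticeably, which is one reason to pass round_to=None when the result feeds into further computation.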
Notes
The computation is faster for the marginals (joint=False). Computing the distance over the marginals is equivalent to assuming that the marginals are independent, which is usually not the case. This function uses scipy.stats.wasserstein_distance for the marginals and scipy.stats.wasserstein_distance_nd for the joint distribution.
References
[1] “Wasserstein metric”, https://en.wikipedia.org/wiki/Wasserstein_metric
Examples
Calculate the Wasserstein distance between the posterior distributions for the variable mu in the centered and non-centered eight schools models:
In [1]: from arviz_stats import wasserstein
   ...: from arviz_base import load_arviz_data
   ...: data1 = load_arviz_data('centered_eight')
   ...: data2 = load_arviz_data('non_centered_eight')
   ...: wasserstein(data1, data2, var_names="mu")
   ...:
Out[1]: 0.29