Using arviz_stats array interface#
This tutorial covers how to use the arviz_stats array interface for diagnosing and summarizing Bayesian modeling results stored as NumPy arrays. It is aimed at advanced users and developers of other libraries, for example developers of probabilistic programming languages who want to incorporate sampling diagnostics into their library.
What is the “array interface”?#
The array interface is the base building block on top of which everything within arviz_stats is built, and it is always available.
When you install arviz_stats with pip install arviz_stats (instead of the recommended way shown in Installation), you get a minimal package with NumPy and SciPy as the only dependencies and array_stats as the way to interface with the functions of the library.
As the array interface does not depend on arviz_base, defaults are either hardcoded or not set, so some arguments that are optional in the top-level functions or the xarray interfaces become required. Others, like the axis arguments, do have defaults but, much like the default axis of NumPy functions, you should never assume they will work for your specific case. In this tutorial we explicitly set the axis arguments in all function calls.
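To illustrate why relying on axis defaults is risky, here is a small NumPy-only sketch (it does not call arviz_stats). It assumes an array laid out as (chain, draw, variable), like the mock data used later in this tutorial, and shows that different functions default to different axes.

```python
import numpy as np

rng = np.random.default_rng(0)
# assumed layout: (chain, draw, variable)
samples = rng.normal(size=(4, 100, 7))

# np.mean defaults to axis=None: the whole array collapses to one number
mean_default = np.mean(samples)
# np.sort defaults to axis=-1: here that sorts across variables, not draws
sorted_default = np.sort(samples)
# explicit axes: one mean per variable, reducing chain and draw
mean_per_variable = np.mean(samples, axis=(0, 1))

print(mean_default.shape)       # ()
print(sorted_default.shape)     # (4, 100, 7)
print(mean_per_variable.shape)  # (7,)
```

None of the default calls raises an error, which is why a wrong default can go unnoticed.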
See also
For an in-depth explanation of the modules in arviz_stats and its architecture, well beyond what is necessary for this tutorial, see Architecture.
Importing the array interface#
The array interface is not a module but a Python class. It can be imported with:
from arviz_stats.base import array_stats
# you can also give an alias to the array_stats class such as
# from arviz_stats.base import array_stats as az
# then use `az.ess` and so on
MCMC diagnostics#
In MCMC there are two dimensions with special meaning, “chain” and “draw”, so the array interface for such functions has two “axis” arguments: chain_axis indicates which dimension represents the chain and draw_axis indicates which one represents the draw.
# generate mock MCMC-like data
import numpy as np
rng = np.random.default_rng()
samples = rng.normal(size=(4, 100, 7))
array_stats.ess(samples, chain_axis=0, draw_axis=1)
array([459.34245009, 401.08972477, 329.3874491 , 401.4836714 ,
506.45514111, 460.83796322, 426.69631509])
axis = {"chain_axis": 0, "draw_axis": 1}
array_stats.rhat_nested(samples, (0, 0, 1, 1), **axis)
array([1.00178091, 1.00001307, 1.0187187 , 1.00321008, 1.00306606,
1.0018201 , 1.00551756])
It is also possible to use chain_axis=None when there is no chain dimension. Some diagnostics like ess or mcse still work, as shown in the example below, whereas others like rhat make no sense without multiple chains, so using chain_axis=None will always result in NaNs as output.
Similarly, attempting to compute ess on an array with fewer than 4 draws also outputs NaNs in the expected shape.
array_stats.mcse(samples, chain_axis=None, draw_axis=1, method="sd")
array([[0.07888795, 0.08384865, 0.07420566, 0.07939282, 0.05469291,
0.07176403, 0.05994337],
[0.04798386, 0.06194168, 0.09648721, 0.08752604, 0.08908106,
0.06760488, 0.05030621],
[0.0760319 , 0.06649435, 0.0543176 , 0.07703644, 0.06384178,
0.07591641, 0.08611206],
[0.09009632, 0.10274383, 0.05875394, 0.062174 , 0.06180665,
0.06932181, 0.07687072]])
array_stats.rhat(samples, chain_axis=None, draw_axis=1)
array([[nan, nan, nan, nan, nan, nan, nan],
[nan, nan, nan, nan, nan, nan, nan],
[nan, nan, nan, nan, nan, nan, nan],
[nan, nan, nan, nan, nan, nan, nan]])
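To see why a multi-chain diagnostic cannot produce a number from a single chain, the sketch below implements the classic (non-split, non-rank-normalized) potential scale reduction factor in plain NumPy. This is an illustration only, not the estimator arviz_stats uses; the function name classic_rhat is hypothetical. With one chain the between-chain variance has zero degrees of freedom, so the result is NaN.

```python
import warnings

import numpy as np


def classic_rhat(x, chain_axis=0, draw_axis=1):
    """Classic potential scale reduction factor (illustration only)."""
    # put chain and draw last: (..., chain, draw)
    x = np.moveaxis(np.asarray(x, dtype=float), (chain_axis, draw_axis), (-2, -1))
    n_draws = x.shape[-1]
    with warnings.catch_warnings(), np.errstate(invalid="ignore", divide="ignore"):
        warnings.simplefilter("ignore", RuntimeWarning)
        # between-chain variance: undefined (NaN) when there is a single chain
        between = n_draws * x.mean(axis=-1).var(axis=-1, ddof=1)
        # average within-chain variance
        within = x.var(axis=-1, ddof=1).mean(axis=-1)
        var_est = (n_draws - 1) / n_draws * within + between / n_draws
        return np.sqrt(var_est / within)


rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 100, 7))

rhat_four_chains = classic_rhat(samples)       # one value per variable, near 1
rhat_single_chain = classic_rhat(samples[:1])  # one chain only: all NaN
print(rhat_four_chains.shape, np.isnan(rhat_single_chain).all())
```

Here the single-chain case keeps a chain dimension of length one, so the output has one NaN per variable; with chain_axis=None in the array interface there is no chain dimension at all and the existing dimensions are all kept as batch dimensions, as in the output above.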
Statistical summaries#
When computing statistical summaries we might want to reduce one or multiple dimensions, so all functions in the array interface have an axis argument that takes an integer, a sequence of integers or None (which indicates all dimensions should be reduced). If a function deviates from this behaviour, its docstring should indicate it and an informative error message should be raised, as is the case for autocorr, shown at the end of this section.
array_stats.hdi(samples, 0.8, axis=(0, 1))
array([[-1.36275886, 1.08142061],
[-1.32559377, 1.09291349],
[-1.13343236, 1.30187928],
[-1.17227971, 1.44383128],
[-1.44747829, 1.27070483],
[-1.29181151, 1.42380106],
[-1.56946813, 1.25501178]])
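The three accepted forms of axis behave like their NumPy counterparts. As a rough stand-in, the sketch below uses np.quantile to compute an equal-tailed 80% interval (which is not the same thing as the HDI computed above) and prints the resulting shapes for an integer, a tuple and None.

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 100, 7))  # assumed (chain, draw, variable)
probs = [0.1, 0.9]  # equal-tailed 80% interval, not an HDI

# integer: reduce only the draw dimension
eti_per_chain = np.quantile(samples, probs, axis=1)
# sequence of integers: reduce chain and draw together
eti_per_variable = np.quantile(samples, probs, axis=(0, 1))
# None: reduce every dimension
eti_overall = np.quantile(samples, probs, axis=None)

print(eti_per_chain.shape)     # (2, 4, 7)
print(eti_per_variable.shape)  # (2, 7)
print(eti_overall.shape)       # (2,)
```

Note that np.quantile puts the interval bounds on the first axis, whereas the hdi output above has them on the last one.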
array_stats.kde(samples, axis=(0, 1))
(array([[-3.00429826, -2.99271058, -2.9811229 , ..., 2.893829 ,
2.90541668, 2.91700435],
[-3.18961408, -3.17801739, -3.16642071, ..., 2.71309886,
2.72469555, 2.73629223],
[-2.77803127, -2.7674738 , -2.75691633, ..., 2.59572217,
2.60627965, 2.61683712],
...,
[-3.83563002, -3.82229155, -3.80895308, ..., 2.95365043,
2.9669889 , 2.98032737],
[-3.08031093, -3.06870726, -3.05710359, ..., 2.82595522,
2.83755888, 2.84916255],
[-3.04816753, -3.03578764, -3.02340775, ..., 3.25319563,
3.26557552, 3.27795541]], shape=(7, 512)),
array([[0.02231199, 0.0223113 , 0.02230996, ..., 0.01410705, 0.01409952,
0.01409533],
[0.01899806, 0.01899454, 0.01898732, ..., 0.03057282, 0.03056991,
0.03056879],
[0.03233875, 0.03235778, 0.03239163, ..., 0.0322979 , 0.03225837,
0.03223739],
...,
[0.0059255 , 0.00591692, 0.00589968, ..., 0.01850032, 0.01846922,
0.01845192],
[0.02327452, 0.02328063, 0.02329106, ..., 0.02246641, 0.02244466,
0.02243197],
[0.02081782, 0.02084132, 0.02088549, ..., 0.01186477, 0.01187269,
0.01187576]], shape=(7, 512)),
array([0.26214729, 0.30198939, 0.3063812 , 0.32063124, 0.33786006,
0.33326872, 0.35510539]))
import traceback
try:
    array_stats.autocorr(samples, axis=(0, 1))
except ValueError as err:
    traceback.print_exception(err)
Traceback (most recent call last):
  File "/tmp/ipykernel_31735/2391265193.py", line 4, in <module>
    array_stats.autocorr(samples, axis=(0, 1))
  File "/home/oriol/Documents/repos_oss/arviz-stats/src/arviz_stats/base/core.py", line 68, in autocorr
    raise ValueError("Only integer values are allowed for `axis` in autocorr.")
ValueError: Only integer values are allowed for `axis` in autocorr.
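One plausible reading of this restriction is that autocorrelation depends on the ordering of values along a single dimension, so there is no obvious meaning for several axes at once. The sketch below is a NumPy-only, FFT-based autocorrelation along one integer axis. It is an assumption-laden illustration, not the arviz_stats implementation, and the name autocorr_1axis is hypothetical.

```python
import numpy as np


def autocorr_1axis(x, axis=-1):
    """Normalized autocorrelation along one integer axis (sketch)."""
    x = np.moveaxis(np.asarray(x, dtype=float), axis, -1)
    n = x.shape[-1]
    centered = x - x.mean(axis=-1, keepdims=True)
    # zero-pad to 2n so the circular FFT product does not wrap around
    spectrum = np.fft.rfft(centered, n=2 * n, axis=-1)
    autocov = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * n, axis=-1)[..., :n]
    # normalize by the lag-0 term so that lag 0 equals 1
    result = autocov / autocov[..., :1]
    return np.moveaxis(result, -1, axis)


rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 100, 7))  # assumed (chain, draw, variable)

# autocorrelation along the draw axis, independently per chain and variable
acf = autocorr_1axis(samples, axis=1)
print(acf.shape)  # (4, 100, 7)
```

With the array interface, the analogous call would presumably pass the draw dimension as a single integer, for example axis=1 for the mock data in this tutorial.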
Model comparison#
# generate mock pointwise log likelihood
from scipy.stats import norm
log_lik = norm.logpdf(samples, loc=0.2, scale=1.1)
log_weights, khats = array_stats.psislw(-log_lik, axis=(0, 1))
print(f"log_lik shape: {log_lik.shape}")
print(f"log_weights shape: {log_weights.shape}")
print(f"khats shape: {khats.shape}")
# TODO: call loo function with log_weights and khats as inputs
log_lik shape: (4, 100, 7)
log_weights shape: (7, 4, 100)
khats shape: (7,)
Note that the shape of log_weights is not exactly the same as the shape of log_lik. The dimensions on which the function acts are moved to the end.
For functions that reduce these dimensions, like the ones we have used so far or the khats output, this makes no difference; but for log_weights it does. This is because the array interface is one of the building blocks of the DataArray interface, which uses xarray.apply_ufunc. apply_ufunc requires the dimensions the function works on, and any dimension it adds, to be the last ones.
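If downstream code expects the original (chain, draw, variable) layout, the trailing dimensions can be moved back with np.moveaxis. The sketch below uses a placeholder array with the same shape as the log_weights printed above rather than real psislw output.

```python
import numpy as np

# placeholder with the shape reported above: (variable, chain, draw)
log_weights = np.arange(7 * 4 * 100, dtype=float).reshape(7, 4, 100)

# move the trailing (chain, draw) axes back to the front
log_weights_restored = np.moveaxis(log_weights, (-2, -1), (0, 1))
print(log_weights_restored.shape)  # (4, 100, 7)
```

np.moveaxis returns a view, so no data is copied; element [c, d, v] of the restored array is element [v, c, d] of the original.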