This page was generated from doc/source/methods/mmddrift.ipynb.
Maximum Mean Discrepancy
Overview
The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate two-sample testing. The MMD is a distance-based measure between two distributions \(p\) and \(q\), based on their mean embeddings \(\mu_{p}\) and \(\mu_{q}\) in a reproducing kernel Hilbert space \(F\):
\[MMD(F, p, q) = \| \mu_{p} - \mu_{q} \|_{F}\]
After applying the kernel trick, we can compute an unbiased estimate of \(MMD^2\) directly from samples of the two distributions. By default we use a radial basis function (Gaussian) kernel, but users are free to pass their own kernel to the detector. We obtain a \(p\)-value via a permutation test on the values of \(MMD^2\).
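To make the estimator and the permutation test concrete, here is a minimal NumPy sketch; it is not the detector's actual implementation. It assumes a single Gaussian kernel \(k(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))\) and uses the standard unbiased estimate \(\widehat{MMD}^2 = \frac{1}{m(m-1)}\sum_{i \neq j} k(x_i, x_j) + \frac{1}{n(n-1)}\sum_{i \neq j} k(y_i, y_j) - \frac{2}{mn}\sum_{i, j} k(x_i, y_j)\). The function names gaussian_kernel_matrix, mmd2_unbiased_from_kernel and mmd2_permutation_test are hypothetical.

```python
import numpy as np


def gaussian_kernel_matrix(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 * sigma^2)) for all pairs of rows."""
    sq_dists = (
        np.sum(x ** 2, axis=1)[:, None]
        + np.sum(y ** 2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negatives caused by rounding
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def mmd2_unbiased_from_kernel(k: np.ndarray, m: int) -> float:
    """Unbiased MMD^2; rows/columns [:m] belong to sample x, [m:] to sample y."""
    n = k.shape[0] - m
    k_xx, k_yy, k_xy = k[:m, :m], k[m:, m:], k[:m, m:]
    # Excluding the diagonal terms k(x_i, x_i) and k(y_j, y_j) makes the estimate unbiased.
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


def mmd2_permutation_test(x, y, sigma=1.0, n_permutations=100, seed=0):
    """Return (p_value, observed MMD^2) from a permutation test."""
    rng = np.random.default_rng(seed)
    z = np.concatenate([x, y], axis=0)
    m = len(x)
    k = gaussian_kernel_matrix(z, z, sigma)  # kernel on the pooled sample, computed once
    observed = mmd2_unbiased_from_kernel(k, m)
    count = 0
    for _ in range(n_permutations):
        idx = rng.permutation(len(z))  # randomly reassign instances to the 2 samples
        permuted = mmd2_unbiased_from_kernel(k[np.ix_(idx, idx)], m)
        count += permuted >= observed
    # add-one correction keeps the p-value strictly positive
    p_value = (count + 1) / (n_permutations + 1)
    return p_value, observed


# Usage: a reference sample and a mean-shifted test sample
rng = np.random.default_rng(42)
x_ref = rng.normal(0.0, 1.0, size=(100, 2))
x_new = rng.normal(3.0, 1.0, size=(100, 2))
p_val, mmd2 = mmd2_permutation_test(x_ref, x_new, sigma=1.0, n_permutations=50)
```

Computing the kernel matrix on the pooled sample once and only permuting its rows and columns avoids recomputing kernel values for every permutation. The add-one correction is one common choice for permutation p-values; other implementations may use the plain fraction of permuted statistics that exceed the observed one.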
For high-dimensional data, we typically want to reduce the dimensionality before computing the permutation test. Following suggestions in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, we incorporate Untrained AutoEncoders (UAE), black-box shift detection using the classifier's softmax outputs (BBSDs) and PCA as out-of-the-box preprocessing methods. Preprocessing methods that do not rely on the classifier will usually pick up drift in the input data, while BBSDs focuses on label shift.
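As an illustration of PCA as a preprocessing step, the sketch below fits principal components on the reference data with NumPy's SVD and projects any batch onto them. The helpers fit_pca and pca_preprocess are hypothetical and are not the library's own PCA preprocessing; the assumption is that the components are fitted once on the reference data and reused for every test batch, so both samples live in the same low-dimensional space.

```python
import numpy as np


def fit_pca(x_ref: np.ndarray, n_components: int):
    """Fit PCA on flattened reference data via SVD; returns (mean, components)."""
    x_flat = x_ref.reshape(len(x_ref), -1)
    mean = x_flat.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(x_flat - mean, full_matrices=False)
    return mean, vt[:n_components]


def pca_preprocess(x: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Flatten each instance and project it onto the fitted principal components."""
    x_flat = x.reshape(len(x), -1)
    return (x_flat - mean) @ components.T


# Usage: reduce 8x8x3 "images" to 10 dimensions before running the MMD test
rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 8, 8, 3))
x_new = rng.normal(size=(50, 8, 8, 3))
mean, components = fit_pca(x_ref, n_components=10)
z_ref = pca_preprocess(x_ref, mean, components)
z_new = pca_preprocess(x_new, mean, components)
```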
Usage
Initialize
Parameters:
p_val: p-value threshold used for the significance of the permutation test.
X_ref: Data used as the reference distribution.
update_X_ref: Reference data can optionally be updated to the last N instances seen by the detector, or via reservoir sampling with size N. For the former, pass {'last': N}; for reservoir sampling, pass {'reservoir_sampling': N}.
preprocess_fn: Function to preprocess the data before computing the data drift metrics, typically a dimensionality reduction technique. The out-of-the-box methods UAE, BBSDs and PCA are illustrated in the example notebook.
preprocess_kwargs: Keyword arguments for preprocess_fn. Again, see the notebook for concrete examples.
kernel: Kernel function used when computing the MMD. Defaults to a Gaussian kernel.
kernel_kwargs: Keyword arguments for the kernel function. For the Gaussian kernel this is the kernel bandwidth sigma. We can also sum over a number of different kernel bandwidths, in which case sigma becomes an array of values. If sigma is not specified, the detector infers it by computing the pairwise distances between the instances in the 2 samples and setting sigma to the median distance.
n_permutations: Number of permutations used in the permutation test.
chunk_size: Optionally compute the MMD between the 2 samples in chunks using dask, to avoid potential out-of-memory errors. The fastest chunk size depends on the application and hardware, so it is often worth testing a few different values, including None. None means that the computation is done in memory with NumPy.
data_type: Optionally specify the data type added to the metadata, e.g. 'tabular' or 'image'.
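The two update_X_ref strategies can be sketched in NumPy as follows. The helpers update_last and reservoir_update are hypothetical; the reservoir version assumes the standard Algorithm R, with the instances already in the reference set counted as seen, and the detector's internal bookkeeping may differ.

```python
import numpy as np


def update_last(x_ref: np.ndarray, x_batch: np.ndarray, n: int) -> np.ndarray:
    """'last' strategy: keep only the n most recent instances."""
    return np.concatenate([x_ref, x_batch], axis=0)[-n:]


def reservoir_update(x_reservoir, x_batch, n_seen, size, rng):
    """'reservoir_sampling' strategy (Algorithm R).

    n_seen is the number of instances observed before this batch.
    Every instance seen so far ends up in the reservoir with probability size / n_seen.
    """
    reservoir = list(x_reservoir)
    for instance in x_batch:
        n_seen += 1
        if len(reservoir) < size:
            reservoir.append(instance)  # fill the reservoir first
        else:
            j = int(rng.integers(0, n_seen))  # uniform over 0 .. n_seen - 1
            if j < size:
                reservoir[j] = instance  # replace a random slot
    return np.stack(reservoir), n_seen


# Usage: 100 reference instances, then a batch of 50 new ones
rng = np.random.default_rng(0)
x_ref = rng.normal(size=(100, 4))
x_batch = rng.normal(size=(50, 4))
last = update_last(x_ref, x_batch, n=100)
reservoir, n_seen = reservoir_update(x_ref, x_batch, n_seen=100, size=100, rng=rng)
```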
Initialized drift detector example:
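import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer
from alibi_detect.cd import MMDDrift
from alibi_detect.cd.preprocess import uae  # Untrained AutoEncoder
from alibi_detect.utils.kernels import gaussian_kernel  # default Gaussian kernel; module path may differ across versions

# X_ref: reference data, here assumed to be images of shape (N, 32, 32, 3)

# untrained encoder mapping each 32x32x3 image to a 32-dimensional embedding
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(32, 32, 3)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        Dense(32)
    ]
)

cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=uae,  # reduce dimensionality with the untrained encoder
    preprocess_kwargs={'encoder_net': encoder_net, 'batch_size': 128},
    kernel=gaussian_kernel,
    kernel_kwargs={'sigma': np.array([.5, 1., 5.])},  # sum over 3 bandwidths
    chunk_size=1000,  # compute the MMD in chunks of 1000 with dask
    n_permutations=1000
)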
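The example above passes three bandwidths to the Gaussian kernel. A minimal NumPy sketch of what summing over bandwidths and the median heuristic for sigma amount to is shown below. The function names are hypothetical, the scaling \(\exp(-d^2 / (2\sigma^2))\) is an assumed convention, and the median is taken over the distances between the two arrays passed in; the library's gaussian_kernel may scale or aggregate differently.

```python
import numpy as np


def pairwise_distances(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Euclidean distances between every row of x and every row of y."""
    sq = (
        np.sum(x ** 2, axis=1)[:, None]
        + np.sum(y ** 2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    return np.sqrt(np.maximum(sq, 0.0))


def multi_bandwidth_gaussian_kernel(x, y, sigma=None):
    """Sum of Gaussian kernels exp(-d^2 / (2 * s^2)) over the bandwidths s in sigma."""
    dist = pairwise_distances(x, y)
    if sigma is None:
        # Median heuristic: use the median pairwise distance as a single bandwidth.
        sigma = np.median(dist)
    sigma = np.atleast_1d(np.asarray(sigma, dtype=float))
    sq_dist = dist ** 2
    return sum(np.exp(-sq_dist / (2.0 * s ** 2)) for s in sigma)


# Usage: the bandwidths from the example above, and the inferred-sigma case
rng = np.random.default_rng(0)
x = rng.normal(size=(20, 3))
y = rng.normal(size=(30, 3))
k = multi_bandwidth_gaussian_kernel(x, y, sigma=np.array([.5, 1., 5.]))
k_self = multi_bandwidth_gaussian_kernel(x, x, sigma=np.array([.5, 1., 5.]))
k_median = multi_bandwidth_gaussian_kernel(x, x)  # sigma inferred from the data
```

Because each Gaussian term equals 1 at distance 0, summing over three bandwidths gives a kernel whose diagonal is 3 rather than 1; this constant scaling does not change which samples look more similar.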
Detect Drift
We detect data drift by simply calling predict on a batch of instances X. We can return the p-value of the permutation test by setting return_p_val to True.
The prediction takes the form of a dictionary with meta and data keys. meta contains the detector's metadata, while data is also a dictionary, which contains the actual predictions stored in the following keys:
is_drift: 1 if the sample tested has drifted from the reference data and 0 otherwise.
p_val: contains the p-value if return_p_val equals True.
preds_drift = cd.predict(X, return_p_val=True)
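The layout of the returned dictionary can be mirrored with a small hypothetical helper, format_prediction, which is not part of alibi-detect. It assumes drift is flagged when the permutation test's p-value falls below the p_val threshold passed at initialization, and it leaves meta empty as a placeholder for the detector's metadata.

```python
from typing import Optional


def format_prediction(p_value: float, p_val_threshold: float,
                      return_p_val: bool = True, meta: Optional[dict] = None) -> dict:
    """Hypothetical helper reproducing the documented layout of a drift prediction."""
    data = {'is_drift': int(p_value < p_val_threshold)}  # 1 = drift, 0 = no drift
    if return_p_val:
        data['p_val'] = p_value
    return {'meta': dict(meta or {}), 'data': data}


# Usage: reading the prediction the same way as the output of cd.predict
preds = format_prediction(0.003, p_val_threshold=0.05)
if preds['data']['is_drift']:
    print(f"Drift detected (p-value {preds['data']['p_val']:.3f})")
```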