This page was generated from doc/source/methods/classifierdrift.ipynb.
Classifier
Overview
The classifier-based drift detector (Lopez-Paz and Oquab, 2017) simply tries to correctly distinguish instances from the reference set vs. the test set. The classifier is trained to output the probability that a given instance belongs to the test set. If the probabilities it assigns to unseen test instances are significantly higher (as determined by a Kolmogorov-Smirnov test) than those it assigns to unseen reference instances, we conclude that the test set differs from the reference set and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied and the out-of-fold predictions are used for the significance test. Note that a new classifier is trained for each test set, and for each fold when cross-validation is used.
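The procedure above can be sketched in a few lines of numpy and SciPy. This is a minimal illustration, not the alibi-detect implementation: a plain gradient-descent logistic regression (`fit_logreg`, a hypothetical helper) stands in for the user-supplied classifier, folds are built by hand, and the one-sided `scipy.stats.ks_2samp` test compares the out-of-fold "belongs to the test set" probabilities of reference and test instances.

```python
import numpy as np
from scipy.stats import ks_2samp


def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fit_logreg(x, y, lr=0.1, epochs=300):
    # Logistic regression fitted by full-batch gradient descent; a stand-in
    # for the classifier model passed to the detector.
    w, b = np.zeros(x.shape[1]), 0.0
    for _ in range(epochs):
        err = _sigmoid(x @ w + b) - y
        w -= lr * x.T @ err / len(y)
        b -= lr * err.mean()
    return w, b


def classifier_drift(x_ref, x_test, n_folds=5, p_val=0.05, seed=0):
    x = np.concatenate([x_ref, x_test])
    # Label 1 means "this instance comes from the test set".
    y = np.concatenate([np.zeros(len(x_ref)), np.ones(len(x_test))])
    rng = np.random.default_rng(seed)
    # Stratified folds: shuffle each class separately and deal its indices
    # round-robin over the folds, so every fold keeps the class ratio.
    fold = np.empty(len(y), dtype=int)
    for cls in (0, 1):
        idx = rng.permutation(np.flatnonzero(y == cls))
        fold[idx] = np.arange(len(idx)) % n_folds
    # Out-of-fold predictions: a fresh classifier per fold, scored only on
    # the instances it did not see during training.
    probs = np.empty(len(y))
    for k in range(n_folds):
        train, held_out = fold != k, fold == k
        w, b = fit_logreg(x[train], y[train])
        probs[held_out] = _sigmoid(x[held_out] @ w + b)
    # One-sided K-S test: are P(test) scores higher on test instances than on
    # reference instances? (alternative='greater' means CDF_ref > CDF_test.)
    res = ks_2samp(probs[y == 0], probs[y == 1], alternative='greater')
    return {'is_drift': int(res.pvalue < p_val),
            'p_val': float(res.pvalue),
            'distance': float(res.statistic)}


# Usage: a mean shift of 1.5 in both features should be easy to detect.
rng = np.random.default_rng(42)
x_ref = rng.normal(0.0, 1.0, size=(200, 2))
x_shifted = rng.normal(1.5, 1.0, size=(200, 2))
result = classifier_drift(x_ref, x_shifted)
print(result)
```

Using out-of-fold predictions means every instance is scored by a classifier that never trained on it, which is what allows all reference and test data to enter the significance test.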
Usage
Initialize
Arguments:

- `x_ref`: Data used as reference distribution.
- `model`: Classification model used for drift detection. Both TensorFlow and PyTorch models are supported.
Keyword arguments:

- `backend`: Specify the backend (`tensorflow` or `pytorch`). This depends on the framework of the `model`. Defaults to `tensorflow`.
- `p_val`: p-value threshold used for the significance of the test.
- `preprocess_x_ref`: Whether to already apply the (optional) preprocessing step to the reference data at initialization and store the preprocessed data. Depending on the preprocessing step, this can reduce the computation time for the predict step significantly, especially when the reference dataset is large. Defaults to True. It may need to be set to False if the preprocessing step requires statistics from both the reference and test data, such as the mean or standard deviation.
- `update_x_ref`: Reference data can optionally be updated to the last N instances seen by the detector or via reservoir sampling with size N. For the former, the parameter equals `{'last': N}` while for reservoir sampling `{'reservoir_sampling': N}` is passed.
- `preprocess_fn`: Function to preprocess the data before computing the data drift metrics.
- `preds_type`: Whether the model outputs `'probs'` or `'logits'`.
- `binarize_preds`: Whether to test for discrepancy on soft (e.g. probs/logits) model predictions directly with a K-S test, or to binarize them into 0-1 prediction errors and apply a binomial test.
- `train_size`: Optional fraction (float between 0 and 1) of the dataset used to train the classifier. Drift is detected on the remaining `1 - train_size` fraction. Cannot be used in combination with `n_folds`.
- `n_folds`: Optional number of stratified folds used for training. The model predictions are then calculated on all the out-of-fold instances. This makes it possible to leverage all the reference and test data for drift detection at the expense of longer computation. If both `train_size` and `n_folds` are specified, `n_folds` is prioritized.
- `seed`: Optional random seed for fold selection.
- `optimizer`: Optimizer used during training of the classifier. From `torch.optim` for PyTorch and `tf.keras.optimizers` for TensorFlow.
- `learning_rate`: Learning rate for the optimizer.
- `batch_size`: Batch size used during training of the classifier.
- `epochs`: Number of training epochs for the classifier. Applies to each fold if `n_folds` is specified.
- `verbose`: Verbosity level during the training of the classifier. 0 is silent, 1 shows a progress bar and 2 prints the statistics after each epoch.
- `train_kwargs`: Optional additional kwargs for `model.fit()` when fitting the classifier for TensorFlow, or for the built-in PyTorch trainer function (`from alibi_detect.models.pytorch import trainer`).
- `data_type`: Optionally specify the data type (e.g. tabular, image or time-series). Added to metadata.
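The two `update_x_ref` policies can be illustrated with a small standalone function. `update_reference` is a hypothetical helper, not part of alibi-detect; it assumes `n_seen` counts every instance observed so far (including those already in the reference set) and that the initial reference set holds at most N instances. Reservoir sampling here is the classic "Algorithm R", which keeps a uniform random sample of size N from everything seen.

```python
import numpy as np


def update_reference(x_ref, x_new, n_seen, update, rng=None):
    """Hypothetical helper mimicking the two update_x_ref policies.

    x_ref: current reference data; x_new: newly observed batch;
    n_seen: number of instances seen so far, including those in x_ref;
    update: {'last': N} or {'reservoir_sampling': N}.
    Returns the updated reference data and the new n_seen.
    """
    rng = np.random.default_rng() if rng is None else rng
    if 'last' in update:
        # Keep only the N most recent instances.
        n = update['last']
        return np.concatenate([x_ref, x_new])[-n:], n_seen + len(x_new)

    n = update['reservoir_sampling']
    ref = list(x_ref)
    for x in x_new:
        if len(ref) < n:
            # Reservoir not full yet: always keep the instance.
            ref.append(x)
        else:
            # The instance with 0-based index n_seen replaces a random slot
            # with probability n / (n_seen + 1).
            j = rng.integers(0, n_seen + 1)  # uniform over 0..n_seen
            if j < n:
                ref[j] = x
        n_seen += 1
    return np.array(ref), n_seen


# Usage
last, seen = update_reference(np.arange(5), np.arange(5, 12), 5, {'last': 4})
res, seen_r = update_reference(np.arange(3), np.arange(3, 100), 3,
                               {'reservoir_sampling': 3}, np.random.default_rng(0))
```

Keeping the last N instances makes the reference track recent data closely, whereas the reservoir keeps a sample that remains representative of the whole stream seen so far.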
Additional TensorFlow keyword arguments:

- `compile_kwargs`: Optional additional kwargs for `model.compile()` when compiling the classifier.
Additional PyTorch keyword arguments:

- `device`: `cuda` or `gpu` to use the GPU and `cpu` for the CPU. If the device is not specified, the detector will try to leverage the GPU if possible and otherwise fall back on CPU.
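The `preds_type` argument distinguishes a model whose last layer already produces probabilities (`'probs'`, e.g. a softmax output layer) from one that outputs unnormalized scores (`'logits'`). Presumably the detector maps logits to probabilities with a softmax when `preds_type='logits'`; the numpy sketch below only illustrates that conversion for a two-class output, assuming column 1 holds the score for "belongs to the test set". `logits_to_probs` is a hypothetical helper.

```python
import numpy as np


def logits_to_probs(logits):
    """Numerically stable softmax over the class axis (hypothetical helper)."""
    logits = np.asarray(logits, dtype=float)
    # Subtracting the row maximum avoids overflow in exp without changing the result.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


probs = logits_to_probs([[0.0, 0.0], [0.0, 2.0]])
# Assumed convention: column 1 is P(instance belongs to the test set).
p_test = probs[:, 1]
```

For two classes the softmax of column 1 equals the sigmoid of the logit difference, so equal logits map to 0.5.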
Initialized TensorFlow drift detector example:
```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
from alibi_detect.cd import ClassifierDrift

# x_ref: reference data, assumed here to be an array of shape (N, 32, 32, 3)
model = tf.keras.Sequential(
    [
        Input(shape=(32, 32, 3)),
        Conv2D(8, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(16, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(32, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        # Two outputs (reference vs. test); softmax gives probabilities, hence preds_type='probs'
        Dense(2, activation='softmax')
    ]
)
cd = ClassifierDrift(x_ref, model, p_val=.05, preds_type='probs', n_folds=5, epochs=2)
```
A similar detector using PyTorch:
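```python
import torch.nn as nn
from alibi_detect.cd import ClassifierDrift

# PyTorch convolutions expect channels-first inputs, i.e. x_ref of shape (N, 3, 32, 32)
model = nn.Sequential(
    nn.Conv2d(3, 8, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(8, 16, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(128, 2)  # raw scores without softmax, hence preds_type='logits'
)
cd = ClassifierDrift(x_ref, model, backend='pytorch', p_val=.05, preds_type='logits')
```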
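The flattened feature sizes in both examples follow from standard convolution arithmetic, assuming 32x32 inputs as in the examples. With `padding=0` each PyTorch layer maps a spatial size s to floor((s - kernel) / stride) + 1, while Keras `padding='same'` maps s to ceil(s / stride). The sketch below checks that the PyTorch stack ends in 32 channels of 2x2 (hence `nn.Linear(128, 2)`) and that the TensorFlow stack flattens to 32 channels of 4x4.

```python
import math


def conv_out(size, kernel, stride, padding=0):
    # Output length of a convolution along one spatial dimension (no dilation).
    return (size + 2 * padding - kernel) // stride + 1


def same_out(size, stride):
    # Keras padding='same': output length is ceil(size / stride).
    return math.ceil(size / stride)


# PyTorch model: three Conv2d(kernel=4, stride=2, padding=0) layers.
h = 32
for _ in range(3):
    h = conv_out(h, kernel=4, stride=2)
torch_flat = 32 * h * h  # 32 output channels

# TensorFlow model: three Conv2D(..., strides=2, padding='same') layers.
w = 32
for _ in range(3):
    w = same_out(w, stride=2)
tf_flat = 32 * w * w

print(h, torch_flat, w, tf_flat)  # 2 128 4 512
```

Keras infers the input size of `Dense` automatically, whereas `nn.Linear` needs it spelled out, which is why only the PyTorch example hard-codes 128.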
Detect Drift
We detect data drift by simply calling `predict` on a batch of instances `x`. Setting `return_p_val` to True will also return the p-value of the test, and setting `return_distance` to True will return a notion of the strength of the drift.
The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:

- `is_drift`: 1 if the sample tested has drifted from the reference data and 0 otherwise.
- `threshold`: the user-defined threshold defining the significance of the test.
- `p_val`: the p-value of the test if `return_p_val` equals True.
- `distance`: a notion of the strength of the drift if `return_distance` equals True. Equal to the K-S test statistic if `binarize_preds` equals False, or to the relative error reduction over the baseline error expected under the null if `binarize_preds` equals True.
```python
preds = cd.predict(x)
```
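With `binarize_preds=True`, the test operates on correct/incorrect out-of-fold predictions instead of soft scores. The sketch below shows one plausible reading of the binomial test and of the "relative error reduction" distance, under two assumptions not stated in this page: the baseline accuracy under the null is that of always predicting the larger of the two sets, and the p-value is the one-sided tail P(X >= n_correct) for X ~ Binomial(n, baseline), computed with `scipy.stats.binom.sf`. `binarized_drift_test` is a hypothetical helper, not alibi-detect code, and assumes both sets are non-empty.

```python
import numpy as np
from scipy.stats import binom


def binarized_drift_test(y_true, y_pred, p_val=0.05):
    """y_true: 0 = reference, 1 = test; y_pred: out-of-fold 0/1 predictions."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    n = len(y_true)
    n_test = int(y_true.sum())
    # Assumed null baseline: always predict whichever set is larger.
    baseline_acc = max(n_test, n - n_test) / n
    n_correct = int((y_true == y_pred).sum())
    # One-sided binomial test: P(X >= n_correct), X ~ Binomial(n, baseline_acc).
    p = float(binom.sf(n_correct - 1, n, baseline_acc))
    acc = n_correct / n
    # Relative error reduction over the baseline error, clipped at 0 because
    # doing worse than the baseline is no evidence of drift.
    distance = max(0.0, 1 - (1 - acc) / (1 - baseline_acc))
    return {'is_drift': int(p < p_val), 'p_val': p, 'distance': distance}


y_true = np.array([0] * 100 + [1] * 100)
perfect = binarized_drift_test(y_true, y_true)           # 100% accuracy
y_90 = y_true.copy()
y_90[:20] = 1                                             # 180 of 200 correct
ninety = binarized_drift_test(y_true, y_90)
chance = binarized_drift_test(y_true, np.zeros(200, dtype=int))  # baseline level
```

On this reading, going from the 50% baseline to 90% accuracy removes 80% of the baseline error, so the distance is 0.8.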
Saving and loading
The drift detectors can be saved and loaded in the same way as other detectors:
```python
from alibi_detect.utils.saving import save_detector, load_detector

filepath = 'my_path'
save_detector(cd, filepath)
cd = load_detector(filepath)
```
Currently only the TensorFlow backend is supported for `save_detector` and `load_detector`. Adding PyTorch support is a near-term priority.