This page was generated from examples/cd_clf_cifar10.ipynb.
Classifier drift detector on CIFAR-10
Method
The classifier-based drift detector trains a classifier to distinguish instances of the reference data from instances of the test set. If the classifier cannot separate the two sets significantly according to a chosen metric (the classifier accuracy by default), no drift is flagged. If it can, the test set differs from the reference data and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied, in which case the out-of-fold predictions are used to compute the drift metric. Note that a new classifier is trained for each test set, or even for each fold within the test set.
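The idea can be sketched without any deep learning framework. The snippet below is a minimal illustration, not Alibi Detect's implementation: it assumes a hypothetical nearest-centroid classifier (`fit_centroids`, `predict_domain`) and synthetic Gaussian data. Reference instances are labelled 0 and test instances 1, 75% of the shuffled data is used for training, and the held-out accuracy is compared with a 0.55 threshold.

```python
import numpy as np


def fit_centroids(x, y):
    """Hypothetical stand-in classifier: one mean vector per domain label."""
    return np.stack([x[y == 0].mean(axis=0), x[y == 1].mean(axis=0)])


def predict_domain(centroids, x):
    """Assign each instance to the domain with the closest centroid."""
    dists = np.linalg.norm(x[:, None, :] - centroids[None, :, :], axis=-1)
    return dists.argmin(axis=1)


def classifier_drift(x_ref, x_test, threshold=0.55, train_size=0.75, seed=0):
    """Return (is_drift, accuracy) for a reference-vs-test classifier."""
    rng = np.random.default_rng(seed)
    x = np.concatenate([x_ref, x_test])
    # domain labels: 0 = reference, 1 = test
    y = np.concatenate([np.zeros(len(x_ref), int), np.ones(len(x_test), int)])
    perm = rng.permutation(len(x))
    n_train = int(train_size * len(x))
    tr, te = perm[:n_train], perm[n_train:]
    centroids = fit_centroids(x[tr], y[tr])
    accuracy = float((predict_domain(centroids, x[te]) == y[te]).mean())
    return accuracy > threshold, accuracy


rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=(4000, 5))
x_same = rng.normal(0.0, 1.0, size=(4000, 5))     # same distribution
x_shifted = rng.normal(1.0, 1.0, size=(4000, 5))  # mean shifted by 1

drift_same, acc_same = classifier_drift(x_ref, x_same)
drift_shifted, acc_shifted = classifier_drift(x_ref, x_shifted)
print(f"same distribution: drift={drift_same}, accuracy={acc_same:.3f}")
print(f"shifted data:      drift={drift_shifted}, accuracy={acc_shifted:.3f}")
```

When both sets come from the same distribution the held-out accuracy stays close to 0.5, which is why a threshold slightly above chance is a natural choice.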
Backend
The method works with both the PyTorch and TensorFlow frameworks. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch docs for installation instructions.
Dataset
CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness, etc. at different levels of severity, leading to a gradual decline in the classification model performance. We also check for drift against the original test set with class imbalances.
[1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from alibi_detect.cd import ClassifierDrift
from alibi_detect.utils.saving import save_detector, load_detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c
Load data
Original CIFAR-10 data:
[2]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
[3]:
corruptions = corruption_types_cifar10c()
print(corruptions)
['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur']
Let’s pick a subset of the corruptions at severity level 5. Each corruption type consists of perturbations on all of the original test set images.
[4]:
corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']
X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True)
X_corr = X_corr.astype('float32') / 255
We split the original test set into a reference dataset and a dataset which should not be flagged as drift. We also split the corrupted data by corruption type:
[5]:
np.random.seed(0)
n_test = X_test.shape[0]
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
idx_h0 = np.delete(np.arange(n_test), idx, axis=0)
X_ref, y_ref = X_test[idx], y_test[idx]
X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0]
print(X_ref.shape, X_h0.shape)
(5000, 32, 32, 3) (5000, 32, 32, 3)
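The split relies on sampling indices without replacement and taking the complement. A small sketch on a toy array of 10 instances (the array and its size are illustrative assumptions) checks that the two halves are disjoint and together cover the whole set:

```python
import numpy as np

np.random.seed(0)
n = 10                                   # toy dataset size (assumption)
x = np.arange(n) * 10                    # stand-in for the images

idx = np.random.choice(n, size=n // 2, replace=False)
idx_h0 = np.delete(np.arange(n), idx, axis=0)   # complement of idx

x_ref, x_h0 = x[idx], x[idx_h0]
overlap = np.intersect1d(idx, idx_h0)
print(sorted(idx), sorted(idx_h0), overlap.size)
```

`np.setdiff1d(np.arange(n), idx)` would return the same complement.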
[6]:
n_corr = len(corruption)
X_c = [X_corr[i * n_test:(i + 1) * n_test] for i in range(n_corr)]
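This slicing assumes that the corrupted data is stacked along the first axis, one block of `n_test` instances per corruption type, in the order of the `corruption` list. The toy sketch below (array sizes and fill values are made up for illustration) shows the same pattern on a small array where block `k` is filled with the value `k`:

```python
import numpy as np

n_test = 4                                  # toy block size (assumption)
names = ['gaussian_noise', 'motion_blur', 'brightness']
# stand-in for X_corr: block k is filled with the value k
x_corr = np.concatenate([np.full((n_test, 2, 2, 3), k, dtype='float32')
                         for k in range(len(names))])

# one slice of n_test instances per corruption type
x_c = [x_corr[k * n_test:(k + 1) * n_test] for k in range(len(names))]
for name, block in zip(names, x_c):
    print(name, block.shape, block.min(), block.max())
```

For equally sized blocks, `np.split(x_corr, len(names))` produces the same list.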
We can visualise the same instance for each corruption type:
[7]:
i = 6
n_test = X_test.shape[0]
plt.title('Original')
plt.axis('off')
plt.imshow(X_test[i])
plt.show()
# show the same instance under each corruption type
for j, c in enumerate(corruption):
    plt.title(c)
    plt.axis('off')
    plt.imshow(X_corr[n_test * j + i])
    plt.show()
Detect drift with a TensorFlow classifier
Single fold
We use a simple classification model and try to distinguish between the reference data and the corrupted test sets. Initially we’ll use an accuracy threshold set at \(0.55\), use \(75\)% of the shuffled reference and test data for training and evaluate the detector on the remaining \(25\)%. We only train for 1 epoch.
[8]:
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
tf.random.set_seed(0)
model = tf.keras.Sequential(
    [
        Input(shape=(32, 32, 3)),
        Conv2D(8, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(16, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(32, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        Dense(2, activation='softmax')
    ]
)
cd = ClassifierDrift(X_ref, model, threshold=.55, train_size=.75, epochs=1)
# we can also save/load an initialised detector
filepath = 'my_path' # change to directory where detector is saved
save_detector(cd, filepath)
cd = load_detector(filepath)
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
Let’s check whether the detector thinks drift occurred on the different test sets and time the prediction calls:
[9]:
from timeit import default_timer as timer
labels = ['No!', 'Yes!']
def make_predictions(cd, x_h0, x_corr, corruption, metric="accuracy"):
    # time the prediction on the uncorrupted hold-out set
    t = timer()
    preds = cd.predict(x_h0)
    dt = timer() - t
    print('No corruption')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print(f'{metric}: {preds["data"][metric]:.3f}')
    print(f'Time (s) {dt:.3f}')
    if isinstance(x_corr, list):
        # repeat for each corrupted test set
        for x, c in zip(x_corr, corruption):
            t = timer()
            preds = cd.predict(x)
            dt = timer() - t
            print('')
            print(f'Corruption type: {c}')
            print('Drift? {}'.format(labels[preds['data']['is_drift']]))
            print(f'{metric}: {preds["data"][metric]:.3f}')
            print(f'Time (s) {dt:.3f}')
[10]:
make_predictions(cd, X_h0, X_c, corruption)
No corruption
Drift? No!
accuracy: 0.485
Time (s) 2.957
Corruption type: gaussian_noise
Drift? Yes!
accuracy: 0.990
Time (s) 2.263
Corruption type: motion_blur
Drift? Yes!
accuracy: 0.863
Time (s) 2.149
Corruption type: brightness
Drift? Yes!
accuracy: 0.898
Time (s) 2.127
Corruption type: pixelate
Drift? Yes!
accuracy: 0.992
Time (s) 2.167
As expected, drift was only detected on the corrupted datasets and the classifier could easily distinguish the corrupted from the reference data.
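To get a feel for how far these accuracies are from chance, the held-out accuracy can be compared with a fair-coin baseline. The sketch below is an illustration and not part of the detector as used here. It assumes balanced classes, which roughly holds for the no-corruption case where reference and test set both contain 5,000 instances, so that 25% of the 10,000 combined instances, i.e. 2,500, are held out. The helper `chance_tail_probability` is a hypothetical name; it computes an exact one-sided binomial tail probability.

```python
from math import comb


def chance_tail_probability(n, n_correct):
    """Exact P(X >= n_correct) for X ~ Binomial(n, 0.5)."""
    favourable = sum(comb(n, i) for i in range(n_correct, n + 1))
    return favourable / 2 ** n


n_heldout = 2500                      # 25% of 5,000 + 5,000 instances
for accuracy in (0.485, 0.55):
    n_correct = round(accuracy * n_heldout)
    p = chance_tail_probability(n_heldout, n_correct)
    print(f"accuracy {accuracy:.3f}: P(at least this good by chance) = {p:.2e}")
```

Under these assumptions an accuracy at the 0.55 threshold is already very unlikely to arise by chance, while the observed 0.485 is entirely consistent with a classifier that cannot tell the two sets apart.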
Use all the available data via cross-validation
So far we’ve only used \(25\)% of the data to detect the drift, since \(75\)% is used for training. At the cost of additional training time we can, however, leverage all the data via stratified cross-validation. We just need to set the number of folds and keep everything else the same. For each test set, n_folds models are then trained and the out-of-fold predictions are combined for the final drift metric (in this case the accuracy):
[11]:
cd = ClassifierDrift(X_ref, model, threshold=.55, n_folds=5, epochs=1)
WARNING:alibi_detect.cd.base:Both `n_folds` and `train_size` specified. By default `n_folds` is used.
[12]:
make_predictions(cd, X_h0, X_c, corruption)
No corruption
Drift? No!
accuracy: 0.500
Time (s) 7.459
Corruption type: gaussian_noise
Drift? Yes!
accuracy: 0.991
Time (s) 10.255
Corruption type: motion_blur
Drift? Yes!
accuracy: 0.864
Time (s) 9.851
Corruption type: brightness
Drift? Yes!
accuracy: 0.904
Time (s) 9.978
Corruption type: pixelate
Drift? Yes!
accuracy: 0.990
Time (s) 10.290
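The cross-validation scheme can be sketched in plain NumPy. This is a simplified illustration under stated assumptions, not the library's code: a nearest-centroid model stands in for the neural network, the data is synthetic, and `stratified_folds` and `out_of_fold_accuracy` are hypothetical helper names. Each instance is predicted exactly once, by the model that did not see it during training.

```python
import numpy as np


def stratified_folds(y, n_folds, rng):
    """Assign every instance a fold id, balancing each class across folds."""
    fold_id = np.empty(len(y), dtype=int)
    for label in np.unique(y):
        members = rng.permutation(np.flatnonzero(y == label))
        fold_id[members] = np.arange(len(members)) % n_folds
    return fold_id


def out_of_fold_accuracy(x, y, n_folds=5, seed=0):
    """Train one nearest-centroid model per fold, score held-out instances."""
    rng = np.random.default_rng(seed)
    fold_id = stratified_folds(y, n_folds, rng)
    preds = np.full(len(y), -1)
    for k in range(n_folds):
        train, held_out = fold_id != k, fold_id == k
        centroids = np.stack([x[train & (y == c)].mean(axis=0) for c in (0, 1)])
        dists = np.linalg.norm(x[held_out][:, None] - centroids[None], axis=-1)
        preds[held_out] = dists.argmin(axis=1)
    return float((preds == y).mean()), preds


rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=(1000, 5))
x_test = rng.normal(1.0, 1.0, size=(2000, 5))   # shifted, twice as large
x = np.concatenate([x_ref, x_test])
y = np.concatenate([np.zeros(1000, int), np.ones(2000, int)])  # 0 = ref, 1 = test

accuracy, preds = out_of_fold_accuracy(x, y)
print(f"out-of-fold accuracy: {accuracy:.3f}, drift: {accuracy > 0.55}")
```

Stratifying by the reference/test label keeps the ratio of the two sets the same in every fold, which matters when the test set is larger than the reference set, as with the corrupted data above.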
Detect drift with a PyTorch classifier
We can do the same with a PyTorch model instead of a TensorFlow model:
[13]:
import torch
import torch.nn as nn
# set random seed and device
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# define classifier model
model = nn.Sequential(
    nn.Conv2d(3, 8, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(8, 16, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(128, 2)
)
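The input size of 128 for the final linear layer follows from the convolution arithmetic. As a quick check, the sketch below applies the standard output-size formula for a convolution without dilation, floor((size + 2 * padding - kernel) / stride) + 1, three times to a 32 by 32 input; `conv_out` is a helper introduced here for illustration.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                       # CIFAR-10 images are 32 x 32
for _ in range(3):              # three Conv2d layers, kernel 4, stride 2
    size = conv_out(size, kernel=4, stride=2, padding=0)
    print(size)

flat_features = 32 * size * size   # 32 output channels in the last layer
print(flat_features)
```

The spatial size shrinks from 32 to 15, 6 and finally 2, so flattening the 32 channels gives 32 * 2 * 2 = 128 features. The TensorFlow model above uses `padding='same'` instead, which halves the size at each stride-2 layer (32, 16, 8, 4) and therefore flattens to 512 features.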
Since our PyTorch classifier expects the images in a (batch size, channels, height, width) format, we transpose the data. Note that this step could also be passed to the drift detector via the preprocess_fn kwarg:
[14]:
def permute_c(x):
    # channels-last (N, H, W, C) -> channels-first (N, C, H, W)
    return np.transpose(x.astype(np.float32), (0, 3, 1, 2))
X_ref_pt = permute_c(X_ref)
X_h0_pt = permute_c(X_h0)
X_c_pt = [permute_c(xc) for xc in X_c]
print(X_ref_pt.shape, X_h0_pt.shape, X_c_pt[0].shape)
(5000, 3, 32, 32) (5000, 3, 32, 32) (10000, 3, 32, 32)
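A quick way to convince yourself that the transpose only reorders the axes, without mixing up pixels, is to run it on a small array with distinct values. The toy batch below is an assumption made for illustration:

```python
import numpy as np


def permute_c(x):
    """Channels-last (N, H, W, C) to channels-first (N, C, H, W)."""
    return np.transpose(x.astype(np.float32), (0, 3, 1, 2))


x = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)   # toy batch (assumption)
x_pt = permute_c(x)

# pixel (n, h, w, c) must land at (n, c, h, w)
same_pixel = x[1, 2, 3, 0] == x_pt[1, 0, 2, 3]
# transposing with (0, 2, 3, 1) restores the channels-last layout
x_back = np.transpose(x_pt, (0, 2, 3, 1))
print(x_pt.shape, bool(same_pixel), np.array_equal(x_back, x))
```

The axis order `(0, 2, 3, 1)` is the inverse permutation, which is handy when corrupted instances need to be plotted again with matplotlib after the conversion.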
[15]:
# we again use the cross-validation approach
cd = ClassifierDrift(X_ref_pt, model, backend='pytorch', threshold=.55, n_folds=5, epochs=1)
WARNING:alibi_detect.cd.base:Both `n_folds` and `train_size` specified. By default `n_folds` is used.
[16]:
make_predictions(cd, X_h0_pt, X_c_pt, corruption)
No corruption
Drift? No!
accuracy: 0.500
Time (s) 6.297
Corruption type: gaussian_noise
Drift? Yes!
accuracy: 0.989
Time (s) 9.344
Corruption type: motion_blur
Drift? Yes!
accuracy: 0.814
Time (s) 9.352
Corruption type: brightness
Drift? Yes!
accuracy: 0.893
Time (s) 9.228
Corruption type: pixelate
Drift? Yes!
accuracy: 0.971
Time (s) 9.128