This page was generated from examples/cd_text_imdb.ipynb.
Text drift detection on IMDB movie reviews
Method
We detect drift on text data using both the Maximum Mean Discrepancy (MMD) and Kolmogorov-Smirnov (K-S) detectors. In this example notebook we focus on detecting covariate shift \(\Delta p(x)\), since detecting drift in the predicted label distribution works the same way as for other modalities (check K-S and MMD drift on CIFAR-10).
Picking up input data drift \(\Delta p(x)\) is, however, a little more involved for text. When we deal with tabular or image data, we can either apply the two-sample hypothesis test directly on the input or run the test after a preprocessing step, for instance with an Untrained AutoEncoder (UAE) as proposed in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift. This is not as straightforward for text, whether in string or tokenized format, since neither directly represents the semantics of the input.
As a result, we extract (contextual) embeddings for the text and detect drift on those. This procedure has a significant impact on the type of drift we detect. Strictly speaking we are no longer detecting \(\Delta p(x)\), since the whole training procedure (objective function, training data, etc.) for the (pre)trained embeddings has an impact on the embeddings we extract.
The library contains functionality to leverage pre-trained embeddings from HuggingFace’s transformers package but also allows you to easily use your own embeddings of choice. Both options are illustrated with examples in this notebook.
Dataset
Binary sentiment classification dataset containing \(25,000\) movie reviews for training and \(25,000\) for testing. We use the nlp package to load the data, which you can install via pip: pip install nlp
[1]:
import nlp
import numpy as np
import os
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer, BertConfig
from alibi_detect.cd import KSDrift, MMDDrift
from alibi_detect.cd.preprocess import UAE
from alibi_detect.models.embedding import TransformerEmbedding
from alibi_detect.utils.saving import save_detector, load_detector
Load model
[2]:
model_name = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home/avl/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391
INFO:transformers.configuration_utils:Model config BertConfig {
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 28996
}
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/avl/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
Load data
[3]:
def load_dataset(dataset: str, split: str = 'test'):
    data = nlp.load_dataset(dataset)
    X, y = [], []
    for x in data[split]:
        X.append(x['text'])
        y.append(x['label'])
    X = np.array(X)
    y = np.array(y)
    return X, y
[4]:
X, y = load_dataset('imdb', split='train')
print(X.shape, y.shape)
INFO:nlp.load:Checking /home/avl/.cache/huggingface/datasets/d3b7716978cb901261e59327d43b04c52d6d29e50eeac39bea0816865a584081.7c39fd6270c5ee55bcf2e4de23af77ef299e0df65be3f3e84454dcef7175844a.py for additional imports.
INFO:filelock:Lock 140684143876368 acquired on /home/avl/.cache/huggingface/datasets/d3b7716978cb901261e59327d43b04c52d6d29e50eeac39bea0816865a584081.7c39fd6270c5ee55bcf2e4de23af77ef299e0df65be3f3e84454dcef7175844a.py.lock
INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/imdb/imdb.py at /home/avl/anaconda3/envs/detect/lib/python3.7/site-packages/nlp/datasets/imdb
INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/imdb/imdb.py at /home/avl/anaconda3/envs/detect/lib/python3.7/site-packages/nlp/datasets/imdb/76cdbd7249ea3548c928bbf304258dab44d09cd3638d9da8d42480d1d1be3743
INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/imdb/imdb.py to /home/avl/anaconda3/envs/detect/lib/python3.7/site-packages/nlp/datasets/imdb/76cdbd7249ea3548c928bbf304258dab44d09cd3638d9da8d42480d1d1be3743/imdb.py
INFO:nlp.load:Updating dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/imdb/dataset_infos.json to /home/avl/anaconda3/envs/detect/lib/python3.7/site-packages/nlp/datasets/imdb/76cdbd7249ea3548c928bbf304258dab44d09cd3638d9da8d42480d1d1be3743/dataset_infos.json
INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/imdb/imdb.py at /home/avl/anaconda3/envs/detect/lib/python3.7/site-packages/nlp/datasets/imdb/76cdbd7249ea3548c928bbf304258dab44d09cd3638d9da8d42480d1d1be3743/imdb.json
INFO:filelock:Lock 140684143876368 released on /home/avl/.cache/huggingface/datasets/d3b7716978cb901261e59327d43b04c52d6d29e50eeac39bea0816865a584081.7c39fd6270c5ee55bcf2e4de23af77ef299e0df65be3f3e84454dcef7175844a.py.lock
INFO:nlp.builder:No config specified, defaulting to first: imdb/plain_text
INFO:nlp.info:Loading Dataset Infos from /home/avl/anaconda3/envs/detect/lib/python3.7/site-packages/nlp/datasets/imdb/76cdbd7249ea3548c928bbf304258dab44d09cd3638d9da8d42480d1d1be3743
INFO:nlp.builder:Overwrite dataset info from restored data version.
INFO:nlp.info:Loading Dataset info from /home/avl/.cache/huggingface/datasets/imdb/plain_text/1.0.0
INFO:nlp.builder:Reusing dataset imdb (/home/avl/.cache/huggingface/datasets/imdb/plain_text/1.0.0)
INFO:nlp.builder:Constructing Dataset for split None, from /home/avl/.cache/huggingface/datasets/imdb/plain_text/1.0.0
(25000,) (25000,)
Let’s look at a negative and a positive review, respectively:
[5]:
labels = ['Negative', 'Positive']
print(labels[y[-1]])
print(X[-1])
Negative
This is one of the dumbest films, I've ever seen. It rips off nearly ever type of thriller and manages to make a mess of them all.<br /><br />There's not a single good line or character in the whole mess. If there was a plot, it was an afterthought and as far as acting goes, there's nothing good to say so Ill say nothing. I honestly cant understand how this type of nonsense gets produced and actually released, does somebody somewhere not at some stage think, 'Oh my god this really is a load of shite' and call it a day. Its crap like this that has people downloading illegally, the trailer looks like a completely different film, at least if you have download it, you haven't wasted your time or money Don't waste your time, this is painful.
[6]:
print(labels[y[2]])
print(X[2])
Positive
Brilliant over-acting by Lesley Ann Warren. Best dramatic hobo lady I have ever seen, and love scenes in clothes warehouse are second to none. The corn on face is a classic, as good as anything in Blazing Saddles. The take on lawyers is also superb. After being accused of being a turncoat, selling out his boss, and being dishonest the lawyer of Pepto Bolt shrugs indifferently "I'm a lawyer" he says. Three funny words. Jeffrey Tambor, a favorite from the later Larry Sanders show, is fantastic here too as a mad millionaire who wants to crush the ghetto. His character is more malevolent than usual. The hospital scene, and the scene where the homeless invade a demolition site, are all-time classics. Look for the legs scene and the two big diggers fighting (one bleeds). This movie gets better each time I see it (which is quite often).
We draw two samples from the loaded training split: a reference dataset and a dataset for which the H0 of the statistical test should not be rejected. We also create imbalanced datasets and inject selected words in the reference set.
[7]:
def random_sample(X: np.ndarray, y: np.ndarray, proba_zero: float, n: int):
    # class indices for integer labels (1D) or one-hot encoded labels (2D)
    if len(y.shape) == 1:
        idx_0 = np.where(y == 0)[0]
        idx_1 = np.where(y == 1)[0]
    else:
        idx_0 = np.where(y[:, 0] == 1)[0]
        idx_1 = np.where(y[:, 1] == 1)[0]
    n_0 = int(n * proba_zero)
    n_1 = n - n_0  # avoids float truncation, e.g. int(1000 * (1 - .9)) == 99
    idx_0_out = np.random.choice(idx_0, n_0, replace=False)
    idx_1_out = np.random.choice(idx_1, n_1, replace=False)
    X_out = np.concatenate([X[idx_0_out], X[idx_1_out]])
    y_out = np.concatenate([y[idx_0_out], y[idx_1_out]])
    return X_out, y_out


def padding_last(x: np.ndarray, seq_len: int) -> tuple:
    try:  # try not to replace padding token
        last_token = np.where(x == 0)[0][0]
    except IndexError:  # no padding
        last_token = seq_len - 1
    return 1, last_token


def padding_first(x: np.ndarray, seq_len: int) -> tuple:
    try:  # try not to replace padding token
        first_token = np.where(x == 0)[0][-1] + 2
    except IndexError:  # no padding
        first_token = 0
    return first_token, seq_len - 1


def inject_word(token: int, X: np.ndarray, perc_chg: float, padding: str = 'last'):
    seq_len = X.shape[1]
    n_chg = int(perc_chg * .01 * seq_len)  # nb of tokens to replace per instance
    X_cp = X.copy()
    for i in range(X.shape[0]):
        if padding == 'last':
            first_token, last_token = padding_last(X_cp[i, :], seq_len)
        else:
            first_token, last_token = padding_first(X_cp[i, :], seq_len)
        if last_token <= n_chg:
            choice_len = seq_len
        else:
            choice_len = last_token
        idx = np.random.choice(np.arange(first_token, choice_len), n_chg, replace=False)
        X_cp[i, idx] = token
    return X_cp
Reference, H0 and imbalanced data:
[8]:
# proba_zero = fraction with label 0 (=negative sentiment)
n_sample = 1000
X_ref = random_sample(X, y, proba_zero=.5, n=n_sample)[0]
X_h0 = random_sample(X, y, proba_zero=.5, n=n_sample)[0]
n_imb = [.1, .9]
X_imb = {_: random_sample(X, y, proba_zero=_, n=n_sample)[0] for _ in n_imb}
Inject words in reference data:
[9]:
words = ['fantastic', 'good', 'bad', 'horrible']
perc_chg = [1., 5.]  # % of tokens to change in an instance
words_tf = tokenizer.encode(words, return_tensors='tf')
words_tf = list(words_tf.numpy()[0, 1:-1])
max_len = 100
tokens = tokenizer.batch_encode_plus(X_ref, pad_to_max_length=True,
                                     max_length=max_len, return_tensors='tf')
X_word = {}
for i, w in enumerate(words_tf):
    X_word[words[i]] = {}
    for p in perc_chg:
        x = inject_word(w, tokens['input_ids'].numpy(), p)
        dec = tokenizer.batch_decode(x, skip_special_tokens=True)
        X_word[words[i]][p] = np.array(dec)
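To make the perturbation sizes concrete: inject_word replaces int(perc_chg * .01 * seq_len) tokens per instance, so with max_len = 100 the two settings above change 1 and 5 tokens per review. A minimal sketch of that arithmetic follows; the helper name n_tokens_changed is hypothetical and not part of the notebook.

```python
def n_tokens_changed(perc_chg: float, seq_len: int) -> int:
    # same expression as the one used inside inject_word
    return int(perc_chg * .01 * seq_len)


for perc in (1., 5.):
    # number of tokens replaced in a sequence of 100 tokens
    print(perc, n_tokens_changed(perc, seq_len=100))
```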
Preprocessing
First we need to specify the type of embedding we want to extract from the BERT model. We can extract embeddings from the …
pooler_output: Last layer hidden-state of the first token of the sequence (classification token; CLS) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training. Note: this output is usually not a good summary of the semantic content of the input; you are often better off averaging or pooling the sequence of hidden-states for the whole input sequence.
last_hidden_state: Sequence of hidden-states at the output of the last layer of the model, averaged over the tokens.
hidden_state: Hidden states of the model at the output of each layer, averaged over the tokens.
hidden_state_cls: See hidden_state but use the CLS token output.
If hidden_state or hidden_state_cls is used as embedding type, you also need to pass the layer numbers used to extract the embedding from. As an example we extract embeddings from the last 8 hidden states.
[10]:
emb_type = 'hidden_state'
n_layers = 8
layers = [-_ for _ in range(1, n_layers + 1)]
Embedding = TransformerEmbedding(model_name, emb_type, layers)
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home/avl/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391
INFO:transformers.configuration_utils:Model config BertConfig {
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"output_hidden_states": true,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 28996
}
INFO:transformers.modeling_tf_utils:loading weights file https://cdn.huggingface.co/bert-base-cased-tf_model.h5 from cache at /home/avl/.cache/torch/transformers/17e64dc7dc200314bc70dd8198010773501bcabb65a493c1ae7183b8c9a5b1ff.908e74db1113031d6827eb22808cf370b0aeded6e6ac20d0f07af0a334e195cc.h5
INFO:transformers.modeling_tf_utils:Layers from pretrained model not used in TFBertModel: ['nsp___cls', 'mlm___cls']
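To picture what the hidden_state option described above computes, the sketch below averages the token representations of the selected layers, using random arrays in place of real BERT outputs. It is an illustration under assumptions, not the library's implementation: the array shapes, the helper name mean_hidden_state and the choice to ignore padding tokens are all hypothetical.

```python
import numpy as np


def mean_hidden_state(hidden_states, layers):
    """Average token vectors over the selected layers.

    hidden_states: list of arrays, each of shape (batch, seq_len, hidden_size)
    layers: indices into hidden_states, e.g. [-1, ..., -8]
    """
    # stack the chosen layers along the token axis:
    # (batch, n_layers * seq_len, hidden_size)
    hs = np.concatenate([hidden_states[layer] for layer in layers], axis=1)
    # average over all stacked tokens: (batch, hidden_size)
    return hs.mean(axis=1)


rng = np.random.default_rng(0)
batch, seq_len, hidden_size = 5, 16, 768
# 13 arrays: embedding output plus 12 transformer layers (assumed layout)
hidden_states = [rng.normal(size=(batch, seq_len, hidden_size)) for _ in range(13)]
layers = [-i for i in range(1, 9)]  # the last 8 hidden states
emb = mean_hidden_state(hidden_states, layers)
print(emb.shape)  # (5, 768)
```

Because every layer has the same sequence length here, the result equals the mean of the per-layer token averages.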
Let’s check what an embedding looks like:
[11]:
tokens = tokenizer.batch_encode_plus(X[:5], pad_to_max_length=True,
                                     max_length=max_len, return_tensors='tf')
emb = Embedding(tokens)
print(emb.shape)
(5, 768)
So the embedding used by the drift detector is a \(768\)-dimensional vector for each instance. We will therefore first apply a dimensionality reduction step with an Untrained AutoEncoder (UAE) before conducting the statistical hypothesis test. We use the embedding model as the input for the UAE, which then projects the embedding onto a lower-dimensional space.
[12]:
tf.random.set_seed(0)
[13]:
enc_dim = 32
shape = (emb.shape[1],)
uae = UAE(input_layer=Embedding, shape=shape, enc_dim=enc_dim)
Let’s test this again:
[14]:
emb_uae = uae(tokens)
print(emb_uae.shape)
(5, 32)
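The idea behind an untrained encoder can be sketched with a fixed random projection: because the weights are never trained, the mapping depends only on the random initialisation, which is presumably why the seed is set before the UAE is created. The snippet is a numpy stand-in under that assumption, not alibi-detect's UAE; the single tanh layer and the name random_encoder are hypothetical.

```python
import numpy as np


def random_encoder(input_dim: int, enc_dim: int, seed: int = 0):
    rng = np.random.default_rng(seed)
    # randomly initialised weights that are never trained
    W = rng.normal(scale=1. / np.sqrt(input_dim), size=(input_dim, enc_dim))
    b = np.zeros(enc_dim)
    return lambda x: np.tanh(x @ W + b)


# stand-in for a batch of 768-dimensional BERT embeddings
emb = np.random.default_rng(1).normal(size=(5, 768))
encode = random_encoder(input_dim=768, enc_dim=32)
emb_low = encode(emb)
print(emb_low.shape)  # (5, 32)
```

Rebuilding the encoder with the same seed gives the same projection, so reference and test data are mapped consistently.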
K-S detector
Initialize
We proceed to initialize the drift detector. From here on the detector works the same as for other modalities such as images. Please check the images example or the K-S detector documentation for more information about each of the possible parameters.
[15]:
# define preprocessing step parameters
preprocess_kwargs = {
    'model': uae,
    'tokenizer': tokenizer,
    'max_len': max_len,
    'batch_size': 32
}
cd = KSDrift(
    p_val=.05,
    X_ref=X_ref,  # reference data to test against
    preprocess_X_ref=True,  # store preprocessed X_ref for future predict calls
    preprocess_kwargs=preprocess_kwargs
)
The reference data is already preprocessed and stored to save time at each predict call:
[16]:
assert cd.X_ref.shape[1] == enc_dim
Saving and loading detectors is straightforward. To keep the preprocessing flexible we cannot make strong assumptions about it, so the optional preprocess_fn and/or preprocess_kwargs need to be passed to the loading function. This allows preprocessing steps to be defined in other frameworks such as PyTorch or scikit-learn.
[17]:
filepath = 'my_path'
save_detector(cd, filepath)
cd = load_detector(filepath)
WARNING:alibi_detect.utils.saving:Directory my_path does not exist and is now created.
Detect drift
Let’s first check whether drift is flagged on a sample drawn from the training set in the same way as the reference data.
[18]:
preds_h0 = cd.predict(X_h0, return_p_val=True)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
Drift? No!
p-value: [0.79439443 0.93558097 0.60991895 0.82795686 0.7590978 0.5360543
0.79439443 0.99365413 0.50035924 0.93558097 0.722555 0.26338065
0.06155144 0.14833806 0.5726548 0.99365413 0.96887016 0.64755726
0.6852314 0.40047103 0.03778438 0.50035924 0.46576622 0.9882611
0.85929435 0.06155144 0.64755726 0.28769323 0.12050407 0.5360543
0.50035924 0.05464633]
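Note that one of the per-feature p-values above falls below 0.05 even though no drift is flagged. The K-S test is run separately on each of the 32 encoded features, and the printed decisions are consistent with a Bonferroni correction, where a feature only counts if its p-value is below p_val divided by the number of features. The sketch below applies that decision rule to the printed H0 p-values; it assumes the Bonferroni rule and is not the detector's own code, and the helper name bonferroni_drift is hypothetical.

```python
import numpy as np


def bonferroni_drift(p_vals: np.ndarray, p_val: float = .05):
    # compare the smallest feature-wise p-value with the corrected threshold
    threshold = p_val / len(p_vals)
    return int(p_vals.min() < threshold), threshold


# feature-wise p-values printed for the H0 sample above
p_h0 = np.array([
    0.79439443, 0.93558097, 0.60991895, 0.82795686, 0.7590978, 0.5360543,
    0.79439443, 0.99365413, 0.50035924, 0.93558097, 0.722555, 0.26338065,
    0.06155144, 0.14833806, 0.5726548, 0.99365413, 0.96887016, 0.64755726,
    0.6852314, 0.40047103, 0.03778438, 0.50035924, 0.46576622, 0.9882611,
    0.85929435, 0.06155144, 0.64755726, 0.28769323, 0.12050407, 0.5360543,
    0.50035924, 0.05464633,
])
is_drift, threshold = bonferroni_drift(p_h0)
print(threshold)    # 0.0015625
print(p_h0.min())   # 0.03778438
print(is_drift)     # 0
```

The same rule is consistent with the perturbed results further down: the 1% 'bad' perturbation is not flagged (smallest p-value 0.00203786), while the 1% 'horrible' one is (0.00145631).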
Detect drift on imbalanced and perturbed datasets:
[19]:
for k, v in X_imb.items():
    preds = cd.predict(v, return_p_val=True)
    print('% negative sentiment {}'.format(k * 100))
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('p-value: {}'.format(preds['data']['p_val']))
    print('')
% negative sentiment 10.0
Drift? Yes!
p-value: [1.99518353e-01 3.40991944e-01 7.94394433e-01 5.00359237e-01
5.00359237e-01 5.72654784e-01 3.40991944e-01 2.92505771e-02
6.91903234e-02 9.99870896e-01 4.00471032e-01 9.13475513e-01
8.69054198e-02 1.08282514e-01 4.00471032e-01 7.59097815e-01
4.65766221e-01 4.28151786e-02 1.08282514e-01 9.68870163e-01
8.87938619e-01 2.87693232e-01 8.59294355e-01 3.13561678e-01
9.80161786e-01 3.32311448e-03 9.68870163e-01 7.26078229e-04
3.40991944e-01 9.69783217e-03 6.09918952e-01 4.00471032e-01]
% negative sentiment 90.0
Drift? Yes!
p-value: [2.0958899e-01 1.6769221e-01 4.6035737e-01 8.2974398e-01 5.1026817e-02
3.6845976e-04 1.3959650e-02 3.4691233e-03 8.4126943e-01 2.1410163e-02
3.8667390e-01 2.3008208e-01 2.5486696e-01 3.5239363e-01 5.4717654e-01
2.0238158e-01 9.8914808e-01 3.1717645e-04 1.3743564e-01 7.9330452e-02
8.4309745e-01 1.8943239e-02 9.8286802e-01 6.8996549e-01 9.8215282e-01
1.5109187e-01 8.3644027e-01 3.2119218e-01 1.3976933e-01 4.5221341e-01
1.5445707e-02 4.0925053e-01]
[20]:
for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v, return_p_val=True)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))
        print('')
Word: fantastic -- % perturbed: 1.0
Drift? Yes!
p-value: [8.59294355e-01 4.84188050e-02 2.87693232e-01 8.87938619e-01
2.87693232e-01 3.77843790e-02 2.56591532e-02 5.72654784e-01
3.40991944e-01 9.98707950e-01 7.59097815e-01 4.00471032e-01
3.50604125e-04 8.87938619e-01 9.99870896e-01 9.99543309e-01
8.87938619e-01 7.94394433e-01 8.69054198e-02 8.27956855e-01
9.13475513e-01 3.69972497e-01 3.77843790e-02 6.85231388e-01
2.63380647e-01 1.33834302e-01 7.22554982e-01 9.54058170e-01
1.20504074e-01 2.56591532e-02 9.88261104e-01 8.27956855e-01]
Word: fantastic -- % perturbed: 5.0
Drift? Yes!
p-value: [1.0324168e-03 2.7371673e-15 1.0030026e-10 8.2795686e-01 5.4185481e-08
1.1215132e-18 2.5042721e-10 1.4893160e-02 1.5452230e-12 2.4072561e-04
3.3278044e-02 1.0030026e-10 0.0000000e+00 8.6666952e-04 8.2795686e-01
1.2274054e-03 3.6997250e-01 2.9250577e-02 3.2487294e-19 4.2815179e-02
1.6396579e-04 5.7140245e-15 6.0547863e-20 2.2463709e-02 1.8834284e-17
1.1058156e-11 8.6905420e-02 5.0035924e-01 1.2831078e-31 1.0453582e-26
5.6953936e-06 6.0707825e-04]
Word: good -- % perturbed: 1.0
Drift? Yes!
p-value: [1.9951835e-01 9.9365413e-01 4.2815179e-02 9.9954331e-01 1.9951835e-01
9.9693102e-01 4.6576622e-01 9.3558097e-01 9.3558097e-01 6.8523139e-01
9.5405817e-01 8.5929435e-01 2.2463709e-02 9.9987090e-01 9.9365413e-01
4.6576622e-01 9.9999607e-01 5.7265478e-01 9.9954331e-01 8.6905420e-02
9.1347551e-01 8.5929435e-01 2.8769323e-01 8.5929435e-01 8.5929435e-01
9.9693102e-01 9.9693102e-01 9.9987090e-01 7.9439443e-01 1.6396579e-04
9.1347551e-01 9.6887016e-01]
Word: good -- % perturbed: 5.0
Drift? Yes!
p-value: [6.1319246e-16 5.7265478e-01 9.2428386e-20 9.8826110e-01 6.1410643e-10
3.6997250e-01 3.2578668e-05 4.0047103e-01 1.2274054e-03 5.4185481e-08
7.2255498e-01 2.4034434e-03 1.8678642e-37 7.2131259e-03 1.3383430e-01
3.3837957e-10 3.4099194e-01 8.2482254e-10 9.3558097e-01 1.4201371e-13
9.5405817e-01 1.7943768e-06 1.9783097e-07 1.6396579e-04 8.6666952e-04
6.1551444e-02 4.2815179e-02 3.6997250e-01 6.1497598e-09 0.0000000e+00
5.3222836e-03 5.0035924e-01]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.64755726 0.6852314 0.2406036 0.96887016 0.50035924 0.82795686
0.9999727 0.82795686 0.96887016 0.5360543 0.96887016 0.6852314
0.00203786 0.9882611 0.96887016 0.5726548 0.26338065 0.09710453
0.7590978 0.34099194 0.1640792 0.6852314 0.12050407 0.79439443
0.99870795 0.99870795 0.99999607 0.9540582 0.6852314 0.18111965
0.12050407 0.99999607]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [3.18301190e-08 3.49877549e-09 4.19758215e-16 5.32228360e-03
1.71956726e-05 1.99518353e-01 6.47557259e-01 1.33834302e-01
1.98871276e-04 2.00923371e-13 4.00471032e-01 3.27475419e-07
0.00000000e+00 2.26972293e-06 3.13561678e-01 2.92505771e-02
2.63813617e-05 8.49670826e-18 1.64079204e-01 7.04859247e-08
1.29587926e-24 1.07232365e-08 4.14164800e-17 2.24637091e-02
2.87693232e-01 6.09918952e-01 1.48338065e-01 9.69783217e-03
1.00300261e-10 3.24872937e-19 1.88342838e-17 5.72654784e-01]
Word: horrible -- % perturbed: 1.0
Drift? Yes!
p-value: [0.28769323 0.9540582 0.9998709 0.93558097 0.64755726 0.9540582
0.9540582 0.99365413 0.01293455 0.99999607 0.99870795 0.5360543
0.06155144 0.9998709 0.8879386 0.9540582 0.02925058 0.46576622
0.93558097 0.3699725 0.28769323 0.96887016 0.26338065 0.996931
0.9998709 0.02925058 0.9134755 0.996931 0.996931 0.00145631
0.2406036 0.79439443]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [9.0348902e-17 3.6997250e-01 5.0035924e-01 9.1347551e-01 1.2274054e-03
1.1211077e-02 6.1551444e-02 2.8289410e-03 1.1508834e-28 7.2607823e-04
2.8769323e-01 3.3837957e-10 4.3195025e-41 3.3278044e-02 7.2607823e-04
3.6997250e-01 4.6432427e-09 2.4072561e-04 6.1410643e-10 7.0521917e-14
8.7644167e-30 1.1119027e-05 3.9587313e-15 7.2255498e-01 3.6098195e-06
1.3007273e-15 8.5929435e-01 1.2274054e-03 7.0521917e-14 0.0000000e+00
1.8887657e-15 4.8418805e-02]
MMD detector
Initialize
Again check the images example or the MMD detector documentation for more information about each of the possible parameters.
[21]:
cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,  # reference data to test against
    preprocess_X_ref=True,  # store preprocessed X_ref for future predict calls
    preprocess_kwargs=preprocess_kwargs,
    chunk_size=1000,
    # nb of permutations in the test, set to 10 for runtime purposes;
    # should be much higher for a real test
    n_permutations=10
)
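To illustrate what n_permutations controls, here is a minimal numpy sketch of an MMD permutation test with a Gaussian kernel. It is written under assumptions and is not alibi-detect's implementation: the median-heuristic bandwidth, the unbiased MMD^2 estimate and all function names are choices made for this example.

```python
import numpy as np


def gaussian_kernel_matrix(z: np.ndarray) -> np.ndarray:
    # squared pairwise distances between all rows of z
    sq_norms = np.sum(z ** 2, axis=1)
    sq_dist = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2 * z @ z.T, 0.)
    # median heuristic for the kernel bandwidth (an assumption of this sketch)
    sigma2 = np.median(sq_dist[sq_dist > 0])
    return np.exp(-sq_dist / (2 * sigma2))


def mmd2(k: np.ndarray, n: int) -> float:
    # unbiased MMD^2 estimate from the kernel matrix of the stacked samples,
    # where the first n rows/columns belong to the reference sample
    m = k.shape[0] - n
    k_xx, k_yy, k_xy = k[:n, :n], k[n:, n:], k[:n, n:]
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return term_xx + term_yy - 2 * k_xy.mean()


def mmd_permutation_test(x, y, n_permutations=100, seed=0):
    rng = np.random.default_rng(seed)
    k = gaussian_kernel_matrix(np.concatenate([x, y]))
    n = len(x)
    observed = mmd2(k, n)
    perm_stats = np.empty(n_permutations)
    for i in range(n_permutations):
        # shuffle which instances count as reference and which as test data
        idx = rng.permutation(k.shape[0])
        perm_stats[i] = mmd2(k[np.ix_(idx, idx)], n)
    # fraction of permutations with a statistic at least as large as observed
    p_val = float(np.mean(perm_stats >= observed))
    return observed, p_val


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(100, 8))
x_shift = rng.normal(loc=1., size=(100, 8))  # mean-shifted sample
mmd, p_val = mmd_permutation_test(x_ref, x_shift, n_permutations=10)
print(p_val)
```

In a test of this form, 10 permutations mean the p-value can only take the values 0, 0.1, ..., 1, which is why a much larger number is advisable for a real test.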
Detect drift
H0:
[22]:
preds_h0 = cd.predict(X_h0, return_p_val=True)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
Drift? No!
p-value: 1.0
Imbalanced data:
[23]:
for k, v in X_imb.items():
    preds = cd.predict(v, return_p_val=True)
    print('% negative sentiment {}'.format(k * 100))
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('p-value: {}'.format(preds['data']['p_val']))
    print('')
% negative sentiment 10.0
Drift? Yes!
p-value: 0.0
% negative sentiment 90.0
Drift? Yes!
p-value: 0.0
Perturbed data:
[24]:
for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v, return_p_val=True)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))
        print('')
Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: 0.3
Word: fantastic -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: good -- % perturbed: 1.0
Drift? No!
p-value: 1.0
Word: good -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: 1.0
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: 0.6
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
In this run the MMD detector is less sensitive than the K-S detector on the perturbed sentences: it flags none of the 1% perturbations, whereas the K-S detector flags three of the four.
Train embeddings from scratch
So far we used pre-trained embeddings from a BERT model. We can however also use embeddings from a model trained from scratch. First we define and train a simple classification model consisting of an embedding layer and an LSTM layer.
Load data and train model
[25]:
from tensorflow.keras.datasets import imdb, reuters
from tensorflow.keras.layers import Dense, Embedding, Input, LSTM
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.utils import to_categorical

INDEX_FROM = 3
NUM_WORDS = 10000


def print_sentence(tokenized_sentence: np.ndarray, id2w: dict):
    # print the decoded words followed by the raw token ids
    print(' '.join(id2w[_] for _ in tokenized_sentence))
    print('')
    print(tokenized_sentence)


def mapping_word_id(data):
    # shift the word index to make room for the special tokens
    w2id = data.get_word_index()
    w2id = {k: (v + INDEX_FROM) for k, v in w2id.items()}
    w2id["<PAD>"] = 0
    w2id["<START>"] = 1
    w2id["<UNK>"] = 2
    w2id["<UNUSED>"] = 3
    id2w = {v: k for k, v in w2id.items()}
    return w2id, id2w


def get_dataset(dataset: str = 'imdb', max_len: int = 100):
    if dataset == 'imdb':
        data = imdb
    elif dataset == 'reuters':
        data = reuters
    else:
        raise NotImplementedError
    w2id, id2w = mapping_word_id(data)
    (X_train, y_train), (X_test, y_test) = data.load_data(
        num_words=NUM_WORDS, index_from=INDEX_FROM)
    X_train = sequence.pad_sequences(X_train, maxlen=max_len)
    X_test = sequence.pad_sequences(X_test, maxlen=max_len)
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    return (X_train, y_train), (X_test, y_test), (w2id, id2w)


def imdb_model(X: np.ndarray, num_words: int = 100, emb_dim: int = 128,
               lstm_dim: int = 128, output_dim: int = 2) -> tf.keras.Model:
    inputs = Input(shape=(X.shape[1:]), dtype=tf.float32)
    x = Embedding(num_words, emb_dim)(inputs)
    x = LSTM(lstm_dim, dropout=.5)(x)
    outputs = Dense(output_dim, activation=tf.nn.softmax)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy']
    )
    return model
Load and tokenize data:
[26]:
(X_train, y_train), (X_test, y_test), (word2token, token2word) = \
    get_dataset(dataset='imdb', max_len=max_len)
Let’s check out an instance:
[27]:
print_sentence(X_train[0], token2word)
cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
[1415 33 6 22 12 215 28 77 52 5 14 407 16 82
2 8 4 107 117 5952 15 256 4 2 7 3766 5 723
36 71 43 530 476 26 400 317 46 7 4 2 1029 13
104 88 4 381 15 297 98 32 2071 56 26 141 6 194
7486 18 4 226 22 21 134 476 26 480 5 144 30 5535
18 51 36 28 224 92 25 104 4 226 65 16 38 1334
88 12 16 283 5 16 4472 113 103 32 15 16 5345 19
178 32]
Define and train a simple model:
[28]:
model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)
model.fit(X_train, y_train, batch_size=32, epochs=2,
          shuffle=True, validation_data=(X_test, y_test))
Train on 25000 samples, validate on 25000 samples
Epoch 1/2
25000/25000 [==============================] - 123s 5ms/sample - loss: 0.4250 - accuracy: 0.8016 - val_loss: 0.3550 - val_accuracy: 0.8460
Epoch 2/2
25000/25000 [==============================] - 123s 5ms/sample - loss: 0.2759 - accuracy: 0.8900 - val_loss: 0.3782 - val_accuracy: 0.8367
[28]:
<tensorflow.python.keras.callbacks.History at 0x7ff1a8749d90>
Extract the embedding layer from the trained model and combine with UAE preprocessing step:
[29]:
Embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)
emb = Embedding(X_train[:5])
print(emb.shape)
(5, 100, 256)
[30]:
tf.random.set_seed(0)
shape = tuple(emb.shape[1:])
uae = UAE(input_layer=Embedding, shape=shape, enc_dim=enc_dim)
Again, create reference, H0 and perturbed datasets. Also test against the Reuters news topic classification dataset.
[31]:
X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
tokens = [word2token[w] for w in words]
X_word = {}
for i, t in enumerate(tokens):
    X_word[words[i]] = {}
    for p in perc_chg:
        X_word[words[i]][p] = inject_word(t, X_ref, p, padding='first')
[32]:
# load and tokenize Reuters dataset
(X_reut, y_reut), (w2t_reut, t2w_reut) = \
    get_dataset(dataset='reuters', max_len=max_len)[1:]
# sample random instances
idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)
X_ood = X_reut[idx]
Initialize detector and detect drift
[33]:
# no need for a tokenizer since we are already working with tokenized instances
preprocess_kwargs = {'model': uae, 'batch_size': 128}
cd = KSDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_X_ref=True,
    preprocess_kwargs=preprocess_kwargs
)
H0:
[34]:
preds_h0 = cd.predict(X_h0, return_p_val=True)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
Drift? No!
p-value: [0.64755726 0.43243074 0.8879386 0.2406036 0.9540582 0.9134755
0.85929435 0.9134755 0.08690542 0.96887016 0.722555 0.9134755
0.9134755 0.8879386 0.85929435 0.40047103 0.18111965 0.19951835
0.1338343 0.50035924 0.9540582 0.9134755 0.28769323 0.5360543
0.28769323 0.34099194 0.85929435 0.04841881 0.21933001 0.6852314
0.3699725 0.5360543 ]
Perturbed data:
[35]:
for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v, return_p_val=True)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))
        print('')
Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.9998709 0.9882611 0.9134755 0.8879386 0.9540582 0.9540582
0.60991895 0.93558097 0.28769323 0.99870795 0.82795686 0.82795686
0.5360543 0.9540582 0.6852314 0.99870795 0.9134755 0.99870795
0.82795686 0.9134755 0.9801618 0.996931 0.82795686 0.7590978
0.99870795 0.96887016 0.9801618 0.9882611 0.5726548 0.9995433
0.8879386 0.99365413]
Word: fantastic -- % perturbed: 5.0
Drift? Yes!
p-value: [1.2274054e-03 3.6997250e-01 2.8769323e-01 7.5909781e-01 2.6338065e-01
1.3383430e-01 9.0799862e-05 7.7621467e-02 2.2697229e-06 8.5929435e-01
5.3605431e-01 7.9439443e-01 7.7621467e-02 4.6576622e-01 4.3243074e-01
7.9439443e-01 3.4099194e-01 8.5929435e-01 8.6905420e-02 3.3278044e-02
1.4833806e-01 8.2795686e-01 4.6576622e-01 2.8769323e-01 1.3383430e-01
6.0991895e-01 3.1356168e-01 1.6407920e-01 8.8793862e-01 1.9626908e-02
7.7621467e-02 2.4060360e-01]
Word: good -- % perturbed: 1.0
Drift? No!
p-value: [0.9995433 0.9882611 0.9134755 0.9995433 0.99870795 0.99365413
0.85929435 0.99870795 0.9882611 0.9999727 0.9995433 0.99870795
0.9995433 0.9540582 0.99365413 0.82795686 0.9998709 0.93558097
0.9998709 0.99870795 0.9540582 0.96887016 0.79439443 0.996931
0.7590978 0.93558097 0.9882611 0.996931 0.96887016 0.96887016
0.99365413 0.93558097]
Word: good -- % perturbed: 5.0
Drift? No!
p-value: [0.82795686 0.996931 0.34099194 0.60991895 0.9801618 0.9134755
0.10828251 0.9882611 0.9882611 0.31356168 0.28769323 0.46576622
0.9134755 0.93558097 0.46576622 0.50035924 0.3699725 0.9801618
0.722555 0.60991895 0.7590978 0.9540582 0.722555 0.93558097
0.21933001 0.9134755 0.43243074 0.85929435 0.5726548 0.722555
0.6852314 0.07762147]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.9540582 0.64755726 0.9801618 0.99870795 0.9134755 0.9540582
0.9801618 0.82795686 0.8879386 0.7590978 0.99365413 0.99870795
0.8879386 0.8879386 0.93558097 0.9540582 0.7590978 0.64755726
0.82795686 0.96887016 0.99999607 0.9995433 0.96887016 0.722555
0.722555 0.93558097 0.6852314 0.93558097 0.79439443 0.8879386
0.9134755 0.9882611 ]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [7.21312594e-03 1.12110768e-02 1.20504074e-01 4.00471032e-01
5.00359237e-01 2.92505771e-02 3.13561678e-01 1.96269080e-02
1.08282514e-01 3.32780443e-02 1.71140861e-02 8.59294355e-01
8.37208051e-03 9.35580969e-01 7.59097815e-01 1.48931602e-02
7.94394433e-01 8.69054198e-02 8.69054198e-02 9.13475513e-01
5.00359237e-01 7.22554982e-01 7.59097815e-01 3.50604125e-04
1.20504074e-01 1.48338065e-01 3.77843790e-02 1.64079204e-01
2.40603596e-01 4.28151786e-02 2.40603596e-01 1.99518353e-01]
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: [0.99365413 0.85929435 0.9134755 0.85929435 0.79439443 0.7590978
0.79439443 0.31356168 0.64755726 0.46576622 0.9540582 0.7590978
0.43243074 0.9134755 0.9540582 0.60991895 0.8879386 0.64755726
0.85929435 0.9882611 0.93558097 0.93558097 0.99365413 0.79439443
0.9134755 0.43243074 0.64755726 0.9134755 0.64755726 0.93558097
0.93558097 0.9801618 ]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [1.71140861e-02 5.32228360e-03 4.28151786e-02 2.92505771e-02
2.92505771e-02 1.79437677e-06 4.55808453e-03 1.72444014e-03
6.07078255e-04 1.45630504e-03 1.72444014e-03 2.24637091e-02
8.00581650e-12 6.09918952e-01 1.08282514e-01 4.93855441e-05
1.99518353e-01 1.10792353e-04 7.21312594e-03 4.00471032e-01
4.00471032e-01 4.65766221e-01 3.40991944e-01 2.24637091e-02
7.26078229e-04 4.55808453e-03 8.76041099e-07 2.40344345e-03
1.07232365e-08 7.76214674e-02 1.99518353e-01 4.00471032e-01]
Again the detector is not as sensitive as the Transformer-based K-S drift detector. The embeddings trained from scratch come from a simple model trained on a small dataset with a cross-entropy loss for only 2 epochs. The pre-trained BERT model, on the other hand, captures the semantics of the data better.
Sample from the Reuters dataset:
[36]:
preds_ood = cd.predict(X_ood, return_p_val=True)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))
print('p-value: {}'.format(preds_ood['data']['p_val']))
Drift? Yes!
p-value: [1.34916729e-04 2.83701773e-13 2.53623026e-18 4.28151786e-02
2.63380647e-01 6.20218972e-03 1.22740539e-03 3.77843790e-02
4.64324268e-09 1.08282514e-01 1.84965307e-10 3.32780443e-02
7.13247118e-06 3.24872937e-19 1.64079204e-01 1.00300261e-10
8.91430864e-06 1.45630504e-03 1.79437677e-06 5.46463318e-02
6.09918952e-01 1.47906931e-09 8.12879719e-09 1.34916729e-04
7.76214674e-02 5.46463318e-02 1.47906931e-09 1.71956726e-05
1.47906931e-09 2.13202584e-05 7.13247118e-06 1.22740539e-03]