
class LearnedKernelDriftTF(x_ref, kernel, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_x_ref=None, preprocess_fn=None, n_permutations=100, var_reg=1e-05, reg_loss_fn=<function LearnedKernelDriftTF.<lambda>>, train_size=0.75, retrain_from_scratch=True, optimizer=tensorflow.keras.optimizers.Adam, learning_rate=0.001, batch_size=32, preprocess_batch_fn=None, epochs=3, verbose=0, train_kwargs=None, dataset=<class ''>, input_shape=None, data_type=None)[source]

Bases: BaseLearnedKernelDrift

class JHat(kernel, var_reg)[source]

Bases: Model

A module that wraps the kernel. When passed a batch of reference instances and a batch of test instances, it returns an estimate of a correlate of test power (Equation 4 of Liu et al., 2020).
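As a rough illustration of the quantity this module estimates, the sketch below computes the ratio of an unbiased MMD^2 estimate to its regularised standard deviation for two small one-dimensional samples. It is a standard-library sketch under stated assumptions, not the library's implementation: the names gaussian_kernel and j_hat, the fixed Gaussian kernel, and the truncation to equal-length batches are all hypothetical choices made for the example.

```python
import math


def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian (RBF) similarity between two scalars (assumed kernel)."""
    return math.exp(-((a - b) ** 2) / (2.0 * bandwidth ** 2))


def j_hat(x, y, kernel=gaussian_kernel, var_reg=1e-5):
    """Unbiased MMD^2 estimate divided by its regularised standard deviation."""
    n = min(len(x), len(y))  # assumption: pair instances, truncating to equal length
    x, y = x[:n], y[:n]
    # h[i][j] = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(y_i, x_j)
    h = [
        [
            kernel(x[i], x[j]) + kernel(y[i], y[j])
            - kernel(x[i], y[j]) - kernel(y[i], x[j])
            for j in range(n)
        ]
        for i in range(n)
    ]
    # Unbiased MMD^2: average of h over the off-diagonal entries.
    mmd2 = sum(h[i][j] for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    # Variance estimate of the MMD^2 statistic built from the row sums of h.
    row_sums = [sum(row) for row in h]
    var = 4.0 * sum(r * r for r in row_sums) / n ** 3 - 4.0 * sum(row_sums) ** 2 / n ** 4
    # var_reg keeps the denominator away from zero; max() guards against rounding error.
    return mmd2 / math.sqrt(max(var, 0.0) + var_reg)


shifted = j_hat([0.0, 0.1, -0.2, 0.3], [3.0, 3.1, 2.8, 3.3])
identical = j_hat([0.0, 0.1, -0.2, 0.3], [0.0, 0.1, -0.2, 0.3])
```

A larger value suggests the kernel separates the two batches more reliably, which is why the detector trains the kernel to increase this quantity.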

call(x, y)[source]
Return type


__init__(x_ref, kernel, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_x_ref=None, preprocess_fn=None, n_permutations=100, var_reg=1e-05, reg_loss_fn=<function LearnedKernelDriftTF.<lambda>>, train_size=0.75, retrain_from_scratch=True, optimizer=tensorflow.keras.optimizers.Adam, learning_rate=0.001, batch_size=32, preprocess_batch_fn=None, epochs=3, verbose=0, train_kwargs=None, dataset=<class ''>, input_shape=None, data_type=None)[source]

Maximum Mean Discrepancy (MMD) data drift detector where the kernel is trained to maximise an estimate of the test power. The kernel is trained on a split of the reference and test instances and then the MMD is evaluated on held out instances and a permutation test is performed.
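The split described above can be pictured with a short standard-library sketch. The helper split_instances is hypothetical and only assumes that each set is shuffled and divided by train_size; the library's own splitting logic may differ in detail.

```python
import random


def split_instances(x_ref, x, train_size=0.75, seed=0):
    """Split reference and test instances into kernel-training and held-out parts."""
    rng = random.Random(seed)

    def split(data):
        idx = list(range(len(data)))
        rng.shuffle(idx)  # random assignment to the two parts
        n_train = int(len(data) * train_size)
        train = [data[i] for i in idx[:n_train]]
        held_out = [data[i] for i in idx[n_train:]]
        return train, held_out

    ref_train, ref_held = split(x_ref)
    x_train, x_held = split(x)
    # The kernel is trained on the first pair; the MMD and the permutation
    # test are then evaluated on the second, held-out pair.
    return (ref_train, x_train), (ref_held, x_held)


(ref_train, x_train), (ref_held, x_held) = split_instances(
    list(range(100)), list(range(100, 140))
)
```

Evaluating the test on instances the kernel never saw during training keeps the permutation test valid even though the kernel was fitted to the data.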

For details see Liu et al. (2020): Learning Deep Kernels for Non-Parametric Two-Sample Tests.

  • x_ref (Union[ndarray, list]) – Data used as reference distribution.

  • kernel (Model) – Trainable TensorFlow model that returns a similarity between two instances.

  • p_val (float) – p-value used for the significance of the test.

  • x_ref_preprocessed (bool) – Whether the given reference data x_ref has been preprocessed yet. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.

  • preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.

  • update_x_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last n instances seen by the detector or via reservoir sampling with size n. For the former, the parameter equals {'last': n}, while for reservoir sampling {'reservoir_sampling': n} is passed.

  • preprocess_fn (Optional[Callable]) – Function to preprocess the data before applying the kernel.

  • n_permutations (int) – The number of permutations to use in the permutation test once the MMD has been computed.

  • var_reg (float) – Constant added to the estimated variance of the MMD for stability.

  • reg_loss_fn (Callable) – The regularisation term reg_loss_fn(kernel) is added to the loss function being optimized.

  • train_size (Optional[float]) – Optional fraction (float between 0 and 1) of the dataset used to train the kernel. Drift is detected on the remaining 1 - train_size fraction.

  • retrain_from_scratch (bool) – Whether the kernel should be retrained from scratch for each set of test data or whether it should instead continue training from where it left off on the previous set.

  • optimizer (Optimizer) – Optimizer used during training of the kernel.

  • learning_rate (float) – Learning rate used by optimizer.

  • batch_size (int) – Batch size used during training of the kernel.

  • preprocess_batch_fn (Optional[Callable]) – Optional batch preprocessing function. For example to convert a list of objects to a batch which can be processed by the kernel.

  • epochs (int) – Number of training epochs for the kernel. One epoch corresponds to a pass over the smaller of the reference and test sets.

  • verbose (int) – Verbosity level during the training of the kernel. 0 is silent, 1 shows a progress bar.

  • train_kwargs (Optional[dict]) – Optional additional kwargs when training the kernel.

  • dataset (Callable) – Dataset object used during training.

  • input_shape (Optional[tuple]) – Shape of input data.

  • data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.
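The two update_x_ref strategies in the parameter list can be sketched as follows. This is an illustrative standard-library version, not the library's code: update_reference and its n_seen counter are hypothetical, and the reservoir branch assumes the reference already holds at most n instances.

```python
import random


def update_reference(x_ref, x_new, update_x_ref, n_seen, rng=None):
    """Return (updated reference, updated count of instances seen so far)."""
    rng = rng if rng is not None else random.Random()
    if update_x_ref is None:
        return list(x_ref), n_seen  # reference data stays fixed
    (strategy, n), = update_x_ref.items()  # expects exactly one key, e.g. {'last': n}
    if strategy == "last":
        # Keep only the n most recent instances.
        return (list(x_ref) + list(x_new))[-n:], n_seen + len(x_new)
    if strategy == "reservoir_sampling":
        # Classic reservoir sampling: every instance seen so far has the
        # same chance of being in the size-n reservoir.
        reservoir = list(x_ref)
        for item in x_new:
            n_seen += 1
            if len(reservoir) < n:
                reservoir.append(item)
            else:
                j = rng.randrange(n_seen)
                if j < n:
                    reservoir[j] = item
        return reservoir, n_seen
    raise ValueError(f"unknown update strategy: {strategy!r}")


last_ref, last_seen = update_reference([1, 2, 3], [4, 5], {"last": 3}, n_seen=3)
res_ref, res_seen = update_reference(
    [1, 2, 3], list(range(4, 50)), {"reservoir_sampling": 3}, n_seen=3,
    rng=random.Random(0),
)
```

The 'last' strategy tracks recent data closely, while reservoir sampling keeps a uniform sample of everything seen, so the choice depends on whether gradual change should be absorbed into the reference or flagged.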

kernel: Union[tf.keras.Model, torch.nn.Module]

score(x)[source]

Compute the p-value resulting from a permutation test using the maximum mean discrepancy as a distance measure between the reference data and the data to be tested. The kernel used within the MMD is first trained to maximise an estimate of the resulting test power.


x (Union[ndarray, list]) – Batch of instances.

Return type

Tuple[float, float, float]


  • p-value obtained from the permutation test,

  • the MMD^2 between the reference and test set,

  • and the MMD^2 threshold above which drift is flagged.
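To make the three returned values concrete, here is a minimal standard-library sketch of an MMD permutation test on one-dimensional data. It assumes a fixed Gaussian kernel rather than a trained one, and the names mmd2 and permutation_test, as well as the way the threshold is read off the sorted permutation statistics, are assumptions made for illustration.

```python
import math
import random


def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian (RBF) similarity between two scalars (assumed kernel)."""
    return math.exp(-((a - b) ** 2) / (2.0 * bandwidth ** 2))


def mmd2(x, y, kernel=gaussian_kernel):
    """Unbiased estimate of the squared MMD between samples x and y."""
    n, m = len(x), len(y)
    k_xx = sum(kernel(x[i], x[j]) for i in range(n) for j in range(n) if i != j)
    k_yy = sum(kernel(y[i], y[j]) for i in range(m) for j in range(m) if i != j)
    k_xy = sum(kernel(a, b) for a in x for b in y)
    return k_xx / (n * (n - 1)) + k_yy / (m * (m - 1)) - 2.0 * k_xy / (n * m)


def permutation_test(x_ref, x, p_val=0.05, n_permutations=100, seed=0):
    """Return (p-value, observed MMD^2, MMD^2 threshold)."""
    rng = random.Random(seed)
    observed = mmd2(x_ref, x)
    pooled = list(x_ref) + list(x)
    n = len(x_ref)
    permuted = []
    for _ in range(n_permutations):
        rng.shuffle(pooled)  # reassign instances to the two groups at random
        permuted.append(mmd2(pooled[:n], pooled[n:]))
    # p-value: share of permuted statistics at least as large as the observed one.
    p = sum(stat >= observed for stat in permuted) / n_permutations
    # Threshold: permuted statistic at the p_val position of the descending sort.
    idx = min(int(p_val * n_permutations), n_permutations - 1)
    threshold = sorted(permuted, reverse=True)[idx]
    return p, observed, threshold


x_ref = [0.1 * i for i in range(20)]
x_test = [5.0 + 0.1 * i for i in range(20)]
p, dist, threshold = permutation_test(x_ref, x_test)
```

Drift would be flagged here when p falls below p_val, which happens when the observed MMD^2 exceeds the threshold.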

static trainer(j_hat, datasets, optimizer=tensorflow.keras.optimizers.Adam, learning_rate=0.001, preprocess_fn=None, epochs=20, reg_loss_fn=<function LearnedKernelDriftTF.<lambda>>, verbose=1)[source]

Train the kernel to maximise an estimate of test power using minibatch gradient descent.
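A toy version of this training loop is sketched below, assuming a Gaussian kernel whose only trainable parameter is its log-bandwidth. Everything here is a stand-in chosen for illustration: train_bandwidth is a hypothetical name, a central finite difference replaces automatic differentiation, and the step clipping is a safeguard added for the example rather than something the source describes.

```python
import math
import random


def j_hat(x, y, log_bw, var_reg=1e-5):
    """Test-power proxy for a Gaussian kernel with bandwidth exp(log_bw)."""
    bw = math.exp(log_bw)
    n = len(x)

    def k(a, b):
        return math.exp(-((a - b) ** 2) / (2.0 * bw ** 2))

    h = [
        [k(x[i], x[j]) + k(y[i], y[j]) - k(x[i], y[j]) - k(y[i], x[j]) for j in range(n)]
        for i in range(n)
    ]
    mmd2 = sum(h[i][j] for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    rows = [sum(r) for r in h]
    var = 4.0 * sum(r * r for r in rows) / n ** 3 - 4.0 * sum(rows) ** 2 / n ** 4
    return mmd2 / math.sqrt(max(var, 0.0) + var_reg)


def train_bandwidth(x_ref, x, log_bw=0.0, learning_rate=0.001,
                    batch_size=32, epochs=3, seed=0, eps=1e-4):
    """Minibatch gradient ascent on j_hat; returns the trained bandwidth."""
    rng = random.Random(seed)
    n = min(len(x_ref), len(x))  # one epoch covers the smaller of the two sets
    for _ in range(epochs):
        ref_idx = rng.sample(range(len(x_ref)), n)
        cur_idx = rng.sample(range(len(x)), n)
        for start in range(0, n, batch_size):
            xb = [x_ref[i] for i in ref_idx[start:start + batch_size]]
            yb = [x[i] for i in cur_idx[start:start + batch_size]]
            if len(xb) < 2:
                continue  # the unbiased estimate needs at least two pairs
            # Central finite difference stands in for autodiff (assumption).
            grad = (j_hat(xb, yb, log_bw + eps) - j_hat(xb, yb, log_bw - eps)) / (2.0 * eps)
            # Clip the step so a steep gradient cannot blow up the bandwidth.
            step = max(-0.5, min(0.5, learning_rate * grad))
            log_bw += step  # ascent on j_hat equals descent on -j_hat
    return math.exp(log_bw)


ref = [0.1 * i for i in range(40)]
cur = [1.0 + 0.1 * i for i in range(40)]
bandwidth = train_bandwidth(ref, cur, batch_size=20, epochs=3)
```

Working in log-bandwidth keeps the bandwidth positive without a constraint, a common choice when a scale parameter is optimised by gradient steps.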

Return type
