alibi_detect.models.pytorch.trainer module

alibi_detect.models.pytorch.trainer.trainer(model, loss_fn, dataloader, device, optimizer=torch.optim.Adam, learning_rate=0.001, preprocess_fn=None, epochs=20, reg_loss_fn=<function <lambda>>, verbose=1)

Train a PyTorch model.

Parameters:
  • model (Union[Module, Sequential]) – Model to train.

  • loss_fn (Callable) – Loss function used for training.

  • dataloader (DataLoader) – PyTorch DataLoader supplying the training batches.

  • device (device) – Device used for training.

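  • optimizer (Callable) – Optimizer class used for training, for example torch.optim.Adam. It is instantiated with the model's parameters and learning_rate, so pass the class rather than an instance.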

  • learning_rate (float) – Optimizer’s learning rate.

  • preprocess_fn (Optional[Callable]) – Preprocessing function applied to each training batch.

  • epochs (int) – Number of training epochs.

  • reg_loss_fn (Callable) – Regularisation function. The term reg_loss_fn(model) is added to the loss being optimised.

  • verbose (int) – Whether to print training progress.
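Because reg_loss_fn receives the model itself, any penalty computed from the model's parameters can be plugged in. The sketch below assumes an L2 weight penalty is wanted; the name l2_reg_loss and its coef argument are hypothetical illustrations, not part of alibi-detect.

```python
import torch
import torch.nn as nn


def l2_reg_loss(model: nn.Module, coef: float = 1e-4) -> torch.Tensor:
    """Hypothetical regulariser: coef times the sum of squared parameters.

    The result stays attached to the autograd graph, so adding it to the
    training loss also pushes the weights towards zero during backprop.
    """
    return coef * sum((p ** 2).sum() for p in model.parameters())
```

A different strength could then be passed as, for example, reg_loss_fn=functools.partial(l2_reg_loss, coef=1e-3), which is still called as reg_loss_fn(model).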

Return type:

None
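The None return type means training happens in place: the optimiser updates the parameters of the model object passed in, and the caller keeps using that same object. The following is a minimal sketch of a loop with this signature, not alibi-detect's source. It assumes the dataloader yields (x, y) pairs, preprocess_fn acts on the inputs only, loss_fn is called as loss_fn(prediction, target), and the model is already on device. train_sketch is a hypothetical name.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, loss_fn, dataloader, device, optimizer=torch.optim.Adam,
                 learning_rate=1e-3, preprocess_fn=None, epochs=20,
                 reg_loss_fn=lambda model: 0, verbose=1):
    """Hypothetical illustration of the trainer's behaviour; returns None."""
    opt = optimizer(model.parameters(), lr=learning_rate)  # optimizer is a class
    model.train()
    for epoch in range(epochs):
        running = 0.0
        for x, y in dataloader:
            if preprocess_fn is not None:
                x = preprocess_fn(x)  # assumption: applied to inputs only
            x, y = x.to(device), y.to(device)  # model assumed already on device
            # Task loss plus the regularisation term computed from the model.
            loss = loss_fn(model(x), y) + reg_loss_fn(model)
            opt.zero_grad()
            loss.backward()
            opt.step()
            running += loss.item()
        if verbose:
            print(f"epoch {epoch + 1}/{epochs}: mean loss {running / len(dataloader):.4f}")
    # Nothing is returned: the model's parameters were updated in place.


# Usage: fit a small linear regression and check that the same object improved.
torch.manual_seed(0)
X = torch.randn(64, 3)
Y = X @ torch.tensor([[1.0], [-2.0], [0.5]])
dl = DataLoader(TensorDataset(X, Y), batch_size=16, shuffle=True)
net = nn.Linear(3, 1)
mse = nn.MSELoss()
with torch.no_grad():
    loss_before = mse(net(X), Y).item()
result = train_sketch(net, mse, dl, torch.device("cpu"),
                      learning_rate=0.05, epochs=30, verbose=0)
with torch.no_grad():
    loss_after = mse(net(X), Y).item()
```

Passing the optimiser as a class rather than an instance lets the function bind it to model.parameters() itself, so the caller only has to choose the algorithm and the learning rate.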