torchbearer.callbacks

class torchbearer.callbacks.callbacks.Callback[source]

Base callback class.

Note

All callbacks should subclass this class.

on_backward(state)[source]

Perform some action with the given state as context after backward has been called on the loss.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion(state)[source]

Perform some action with the given state as context after the criterion has been evaluated.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_training(state)[source]

Perform some action with the given state as context after the training loop has completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)[source]

Perform some action with the given state as context at the end of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward(state)[source]

Perform some action with the given state as context after the forward pass (model output) has been completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward_validation(state)[source]

Perform some action with the given state as context after the forward pass (model output) has been completed with the validation data.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample_validation(state)[source]

Perform some action with the given state as context after data has been sampled from the validation generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Perform some action with the given state as context at the start of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Perform some action with the given state as context at the start of the training loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_validation(state)[source]

Perform some action with the given state as context at the start of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
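Every hook on the base class receives the same mutable state dict, so a callback only needs to override the hooks it cares about. The sketch below illustrates this without importing torchbearer. Callback is a hypothetical stand-in with the same hook names as above, and StepCounter and run_fake_fit are hypothetical. The hook order in run_fake_fit is inferred from the hook descriptions rather than taken from the library's training loop, which may differ (for example in which hooks fire during validation).

```python
class Callback:
    """Hypothetical stand-in for the base class: every hook is a no-op."""
    def on_start(self, state): pass
    def on_start_epoch(self, state): pass
    def on_start_training(self, state): pass
    def on_sample(self, state): pass
    def on_forward(self, state): pass
    def on_criterion(self, state): pass
    def on_backward(self, state): pass
    def on_step_training(self, state): pass
    def on_end_training(self, state): pass
    def on_start_validation(self, state): pass
    def on_sample_validation(self, state): pass
    def on_forward_validation(self, state): pass
    def on_step_validation(self, state): pass
    def on_end_validation(self, state): pass
    def on_end_epoch(self, state): pass
    def on_end(self, state): pass


class StepCounter(Callback):
    """Counts optimiser steps by overriding only two hooks."""
    def on_start(self, state):
        state['steps'] = 0

    def on_step_training(self, state):
        state['steps'] += 1


def run_fake_fit(callback, epochs=2, train_batches=3, val_batches=1):
    # Hypothetical loop; hook order inferred from the hook descriptions.
    state = {}
    callback.on_start(state)
    for epoch in range(epochs):
        state['epoch'] = epoch
        callback.on_start_epoch(state)
        callback.on_start_training(state)
        for _ in range(train_batches):
            callback.on_sample(state)
            callback.on_forward(state)
            callback.on_criterion(state)
            callback.on_backward(state)
            # optimiser.step() would run here
            callback.on_step_training(state)
        callback.on_end_training(state)
        callback.on_start_validation(state)
        for _ in range(val_batches):
            callback.on_sample_validation(state)
            callback.on_forward_validation(state)
            callback.on_step_validation(state)
        callback.on_end_validation(state)
        callback.on_end_epoch(state)
    callback.on_end(state)
    return state


final_state = run_fake_fit(StepCounter())
print(final_state['steps'])  # 6 (2 epochs x 3 training batches)
```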
class torchbearer.callbacks.callbacks.CallbackList(callback_list)[source]

The CallbackList class is a wrapper for a list of callbacks which acts as a single callback.

on_backward(state)[source]

Call on_backward on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion(state)[source]

Call on_criterion on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end(state)[source]

Call on_end on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Call on_end_epoch on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_training(state)[source]

Call on_end_training on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)[source]

Call on_end_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward(state)[source]

Call on_forward on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward_validation(state)[source]

Call on_forward_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Call on_sample on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample_validation(state)[source]

Call on_sample_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Call on_start on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Call on_start_epoch on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Call on_start_training on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_validation(state)[source]

Call on_start_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Call on_step_training on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Call on_step_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
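CallbackList lets the training loop treat many callbacks as one. Below is a minimal sketch of the idea, assuming a simple loop over the wrapped callbacks in list order. The classes are hypothetical stand-ins with only two hooks, not torchbearer's code:

```python
class Callback:
    """Hypothetical stand-in base class with two of the hooks."""
    def on_start(self, state): pass
    def on_end_epoch(self, state): pass


class CallbackList(Callback):
    """Sketch: forwards each hook to every wrapped callback, in order."""
    def __init__(self, callback_list):
        self.callback_list = list(callback_list)

    def _for_list(self, hook_name, state):
        for callback in self.callback_list:
            getattr(callback, hook_name)(state)

    def on_start(self, state):
        self._for_list('on_start', state)

    def on_end_epoch(self, state):
        self._for_list('on_end_epoch', state)


class Tag(Callback):
    """Records its tag in the state at the end of each epoch."""
    def __init__(self, tag):
        self.tag = tag

    def on_end_epoch(self, state):
        state.setdefault('log', []).append(self.tag)


callbacks = CallbackList([Tag('a'), Tag('b'), CallbackList([Tag('c')])])
state = {}
callbacks.on_start(state)      # inherited no-op for every Tag
callbacks.on_end_epoch(state)
print(state['log'])  # ['a', 'b', 'c']
```

Because the wrapper is itself a Callback, lists can be nested and still dispatch in order.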

Model Checkpointers

class torchbearer.callbacks.checkpointers.Best(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='auto', period=1, min_delta=0, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL)[source]

Model checkpointer which saves the best model according to a metric.

on_end_epoch(model_state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:model_state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.checkpointers.Interval(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', period=1, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL)[source]

Model checkpointer which saves the model every given number of epochs.

on_end_epoch(model_state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:model_state (dict[str,any]) – The current state dict of the Model.
torchbearer.callbacks.checkpointers.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', save_best_only=False, mode='auto', period=1, min_delta=0)[source]

Save the model every period epochs (every epoch by default). filepath can contain named formatting options, which will be filled with any values from state. For example, if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename. The torch model will be saved to filename.pt and the torchbearer model state will be saved to filename.torchbearer.

Parameters:
  • filepath (str) – Path to save the model file
  • monitor (str) – Quantity to monitor
  • save_best_only (bool) – If save_best_only=True, the latest best model according to the monitored quantity will not be overwritten by a worse one
  • mode (str) – One of {auto, min, max}. If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc this should be max and for val_loss it should be min. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
  • period (int) – Interval (number of epochs) between checkpoints
  • min_delta (float) – If save_best_only=True, this is the minimum improvement required to trigger a save
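The filename templating and the save_best_only decision can be sketched in plain Python. This is an illustration under stated assumptions, not torchbearer's implementation. format_filepath, resolve_mode and BestTracker are hypothetical helpers. The 'auto' rule (maximise when the monitored name contains 'acc', otherwise minimise) is borrowed from Keras-style checkpointers, and min_delta is assumed to be an absolute margin the metric must beat.

```python
import math


def format_filepath(filepath, state):
    # Named fields such as {epoch:02d} are filled from the state dict.
    return filepath.format(**state)


def resolve_mode(mode, monitor):
    # Assumed 'auto' rule: maximise accuracy-like metrics, minimise the rest.
    if mode != 'auto':
        return mode
    return 'max' if 'acc' in monitor else 'min'


class BestTracker:
    """Hypothetical helper: decides whether a save_best_only checkpoint is written."""
    def __init__(self, monitor='val_loss', mode='auto', min_delta=0):
        self.monitor = monitor
        self.mode = resolve_mode(mode, monitor)
        self.min_delta = abs(min_delta)
        self.best = math.inf if self.mode == 'min' else -math.inf

    def should_save(self, state):
        current = state[self.monitor]
        if self.mode == 'min':
            improved = current < self.best - self.min_delta
        else:
            improved = current > self.best + self.min_delta
        if improved:
            self.best = current
        return improved


print(format_filepath('weights.{epoch:02d}-{val_loss:.2f}', {'epoch': 3, 'val_loss': 0.4567}))
# weights.03-0.46

tracker = BestTracker(monitor='val_loss', min_delta=0.01)
decisions = [tracker.should_save({'val_loss': v}) for v in [0.50, 0.495, 0.40, 0.42]]
print(decisions)  # [True, False, True, False]
```

The drop from 0.50 to 0.495 is smaller than min_delta, so it does not count as an improvement.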
class torchbearer.callbacks.checkpointers.MostRecent(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL)[source]

Model checkpointer which saves the most recent model.

on_end_epoch(model_state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:model_state (dict[str,any]) – The current state dict of the Model.

Logging

class torchbearer.callbacks.csv_logger.CSVLogger(filename, separator=', ', batch_granularity=False, write_header=True, append=False)[source]

Callback to log metrics to a csv file.

on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
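The following sketch shows what epoch-level CSV logging involves. It assumes the metrics for an epoch are available as a flat dict in state['metrics'], which is an assumption, and it does not model batch_granularity or append. CSVLoggerSketch is hypothetical. Because the default separator ', ' is two characters long, the sketch joins strings itself rather than using csv.writer, whose delimiter must be a single character.

```python
import io


class CSVLoggerSketch:
    """Hypothetical sketch of per-epoch CSV logging; not torchbearer's implementation."""
    def __init__(self, file, separator=', ', write_header=True):
        self.file = file              # any text file-like object
        self.separator = separator
        self.write_header = write_header

    def on_end_epoch(self, state):
        # Assumption: metrics live in state['metrics'] as a flat dict.
        row = {'epoch': state['epoch'], **state['metrics']}
        if self.write_header:
            self.file.write(self.separator.join(row.keys()) + '\n')
            self.write_header = False  # header is written only once
        self.file.write(self.separator.join(str(v) for v in row.values()) + '\n')


buffer = io.StringIO()
logger = CSVLoggerSketch(buffer)
logger.on_end_epoch({'epoch': 0, 'metrics': {'loss': 0.9, 'acc': 0.5}})
logger.on_end_epoch({'epoch': 1, 'metrics': {'loss': 0.7, 'acc': 0.6}})
print(buffer.getvalue())
```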
class torchbearer.callbacks.printer.ConsolePrinter(validation_label_letter='v')[source]

The ConsolePrinter callback simply outputs the training metrics to the console.

on_end_training(state)[source]

Perform some action with the given state as context after the training loop has completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)[source]

Perform some action with the given state as context at the end of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.printer.Tqdm(validation_label_letter='v')[source]

The Tqdm callback outputs the progress and metrics for training and validation loops to the console using TQDM.

on_end_training(state)[source]

Update the bar with the terminal training metrics and then close.

Parameters:state (dict) – The Model state
on_end_validation(state)[source]

Update the bar with the terminal validation metrics and then close.

Parameters:state (dict) – The Model state
on_start_training(state)[source]

Initialise the TQDM bar for this training phase.

Parameters:state (dict) – The Model state
on_start_validation(state)[source]

Initialise the TQDM bar for this validation phase.

Parameters:state (dict) – The Model state
on_step_training(state)[source]

Update the bar with the metrics from this step.

Parameters:state (dict) – The Model state
on_step_validation(state)[source]

Update the bar with the metrics from this step.

Parameters:state (dict) – The Model state

Tensorboard

class torchbearer.callbacks.tensor_board.TensorBoard(log_dir='./logs', write_graph=True, write_batch_metrics=False, batch_step_size=10, write_epoch_metrics=True, comment='torchbearer')[source]

The TensorBoard callback is used to write metric graphs to TensorBoard. Requires the tensorboardX library for Python.

on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Perform some action with the given state as context at the start of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.tensor_board.TensorBoardImages(log_dir='./logs', comment='torchbearer', name='Image', key='y_pred', write_each_epoch=True, num_images=16, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)[source]

The TensorBoardImages callback writes a selection of images from the validation pass to TensorBoard, using the tensorboardX library and torchvision.utils.make_grid.

on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
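The nrow and padding arguments are passed on to torchvision.utils.make_grid, which tiles the images into a single picture. The helper below is a hypothetical sketch of the grid arithmetic as we understand make_grid to perform it: nrow images per row, with padding pixels around each tile. It computes only the output height and width. The example uses the defaults num_images=16 and nrow=8 with 28x28 inputs:

```python
import math


def grid_shape(num_images, height, width, nrow=8, padding=2):
    """Return (grid_height, grid_width) in pixels for a make_grid-style layout."""
    xmaps = min(nrow, num_images)                # images per row
    ymaps = int(math.ceil(num_images / xmaps))   # number of rows
    cell_h, cell_w = height + padding, width + padding
    # One extra padding strip closes the bottom and right edges.
    return cell_h * ymaps + padding, cell_w * xmaps + padding


print(grid_shape(16, 28, 28))  # (62, 242): 2 rows of 8 tiles
```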
class torchbearer.callbacks.tensor_board.TensorBoardProjector(log_dir='./logs', comment='torchbearer', num_images=100, avg_pool_size=1, avg_data_channels=True, write_data=True, write_features=True, features_key='y_pred')[source]

The TensorBoardProjector callback is used to write images from the validation pass, and optionally their features (see write_data and write_features), to the TensorBoard projector using the tensorboardX library.

on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.

Early Stopping

class torchbearer.callbacks.early_stopping.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto')[source]

Callback to stop training when a monitored quantity has stopped improving.

on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
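The following minimal sketch shows how monitor, min_delta and patience interact. It makes several assumptions: mode='min' (as for val_loss), metrics read from state['metrics'], and stopping requested by setting a state['stop_training'] flag. The class name and both state keys are hypothetical. Patience is treated as the number of consecutive non-improving epochs tolerated before stopping.

```python
class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping, 'min' mode only."""
    def __init__(self, monitor='val_loss', min_delta=0, patience=0):
        self.monitor = monitor
        self.min_delta = abs(min_delta)
        self.patience = patience

    def on_start(self, state):
        self.best = float('inf')
        self.wait = 0

    def on_end_epoch(self, state):
        current = state['metrics'][self.monitor]
        if current < self.best - self.min_delta:
            self.best = current   # improvement: reset the counter
            self.wait = 0
        else:
            self.wait += 1
            if self.wait > self.patience:
                state['stop_training'] = True  # assumed flag name


stopper = EarlyStoppingSketch(patience=1)
state = {}
stopper.on_start(state)
stopped_at = None
for epoch, loss in enumerate([0.9, 0.8, 0.85, 0.82, 0.7]):
    state['metrics'] = {'val_loss': loss}
    stopper.on_end_epoch(state)
    if state.get('stop_training'):
        stopped_at = epoch
        break
print(stopped_at)  # 3: two epochs in a row without beating 0.8
```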
class torchbearer.callbacks.terminate_on_nan.TerminateOnNaN(monitor='running_loss')[source]

Callback that terminates training when the given metric is NaN or inf.

on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
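TerminateOnNaN checks the monitored value at the end of every training step, validation step and epoch. A minimal sketch of that check follows. It assumes the value lives in state['metrics'] and that stopping is signalled through a hypothetical state['stop_training'] flag:

```python
import math


class TerminateOnNaNSketch:
    """Hypothetical sketch: flag training to stop once the monitored value is NaN or inf."""
    def __init__(self, monitor='running_loss'):
        self.monitor = monitor

    def _check(self, state):
        value = state['metrics'].get(self.monitor)
        if value is not None and (math.isnan(value) or math.isinf(value)):
            state['stop_training'] = True  # assumed flag name

    # The same check runs after every training step, validation step and epoch.
    on_step_training = _check
    on_step_validation = _check
    on_end_epoch = _check


callback = TerminateOnNaNSketch()
state = {'metrics': {'running_loss': 0.3}}
callback.on_step_training(state)
print(state.get('stop_training', False))  # False
state['metrics']['running_loss'] = float('nan')
callback.on_step_training(state)
print(state.get('stop_training', False))  # True
```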

Gradient Clipping

class torchbearer.callbacks.gradient_clipping.GradientClipping(clip_value, params=None)[source]

GradientClipping callback, which clips gradient values using torch.nn.utils.clip_grad_value_.

on_backward(state)[source]

Between the backward pass (which computes the gradients) and the step call (which updates the parameters), clip the gradient.

Parameters:state (dict) – The Model state
on_start(state)[source]

If params is None, retrieve the parameters from the model.

Parameters:state (dict) – The Model state
class torchbearer.callbacks.gradient_clipping.GradientNormClipping(max_norm, norm_type=2, params=None)[source]

GradientNormClipping callback, which clips the gradient norm using torch.nn.utils.clip_grad_norm_.

on_backward(state)[source]

Between the backward pass (which computes the gradients) and the step call (which updates the parameters), clip the gradient.

Parameters:state (dict) – The Model state
on_start(state)[source]

If params is None, retrieve the parameters from the model.

Parameters:state (dict) – The Model state
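The two callbacks limit gradients in different ways. clip_grad_value_ clamps each entry independently. clip_grad_norm_ rescales all gradients together so that their combined norm does not exceed max_norm, which preserves their direction. The pure-Python sketch below illustrates both on flat lists of floats instead of tensors. The helper names are hypothetical, and the small epsilon in the norm version mirrors PyTorch's behaviour as we understand it. Only finite norm_type values are handled.

```python
def clip_grad_value(grads, clip_value):
    """Clamp every gradient entry into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_grad_norm(grads, max_norm, norm_type=2.0):
    """Rescale all gradients together so their total p-norm is at most max_norm."""
    total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        return [g * clip_coef for g in grads], total_norm
    return list(grads), total_norm  # already within the limit


print(clip_grad_value([-2.0, 0.5, 3.0], 1.0))  # [-1.0, 0.5, 1.0]
clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)  # 5.0 and approximately [0.6, 0.8]
```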

Learning Rate Schedulers

class torchbearer.callbacks.torch_scheduler.CosineAnnealingLR(T_max, eta_min=0, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch CosineAnnealingLR
class torchbearer.callbacks.torch_scheduler.ExponentialLR(gamma, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch ExponentialLR
class torchbearer.callbacks.torch_scheduler.LambdaLR(lr_lambda, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch LambdaLR
class torchbearer.callbacks.torch_scheduler.MultiStepLR(milestones, gamma=0.1, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch MultiStepLR
class torchbearer.callbacks.torch_scheduler.ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, step_on_batch=False)[source]
Parameters:monitor (str) – The quantity to monitor. (Default value = ‘val_loss’)
See:
PyTorch ReduceLROnPlateau
class torchbearer.callbacks.torch_scheduler.StepLR(step_size, gamma=0.1, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch StepLR
class torchbearer.callbacks.torch_scheduler.TorchScheduler(scheduler_builder, monitor=None, step_on_batch=False)[source]
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Perform some action with the given state as context at the start of the training loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
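The wrapped PyTorch schedulers follow closed-form rules. As a sketch in plain Python with hypothetical helper names: StepLR multiplies the base learning rate by gamma once every step_size steps, and MultiStepLR does so once for each milestone already passed. Here epoch stands for the number of scheduler steps taken. Judging by the parameter name, step_on_batch=True would make that count batches rather than epochs, though this sketch does not model the callback itself.

```python
from bisect import bisect_right


def step_lr(base_lr, epoch, step_size, gamma=0.1):
    # StepLR: decay by gamma every step_size steps.
    return base_lr * gamma ** (epoch // step_size)


def multi_step_lr(base_lr, epoch, milestones, gamma=0.1):
    # MultiStepLR: decay by gamma once for each milestone already reached.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


step_schedule = [step_lr(1.0, e, step_size=2, gamma=0.5) for e in range(6)]
multi_schedule = [multi_step_lr(1.0, e, milestones=[1, 4], gamma=0.5) for e in range(6)]
print(step_schedule)   # [1.0, 1.0, 0.5, 0.5, 0.25, 0.25]
print(multi_schedule)  # [1.0, 0.5, 0.5, 0.5, 0.25, 0.25]
```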

Weight Decay

class torchbearer.callbacks.weight_decay.L1WeightDecay(rate=0.0005, params=None)[source]

WeightDecay callback which uses an L1 norm.

class torchbearer.callbacks.weight_decay.L2WeightDecay(rate=0.0005, params=None)[source]

WeightDecay callback which uses an L2 norm.

class torchbearer.callbacks.weight_decay.WeightDecay(rate=0.0005, p=2, params=None)[source]

Callback which adds a weight decay term to the loss for the given parameters.

on_criterion(state)[source]

Calculate the decay term and add to state[‘loss’].

Parameters:state (dict) – The Model state
on_start(state)[source]

Retrieve params from state[‘model’] if required.

Parameters:state (dict) – The Model state
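The decay term added in on_criterion can be sketched as follows. The exact formula is an assumption. The sketch multiplies rate by the sum of the p-norms of the individual parameters, with parameters represented as flat lists of floats. Under that reading, L1WeightDecay corresponds to p=1 and L2WeightDecay to p=2. The function names are hypothetical.

```python
def weight_decay_term(params, rate=0.0005, p=2):
    """Assumed form: rate * sum over parameters of each parameter's p-norm."""
    total = 0.0
    for param in params:  # each param is a flat list of floats here
        total += sum(abs(w) ** p for w in param) ** (1.0 / p)
    return rate * total


def apply_weight_decay(state, params, rate=0.0005, p=2):
    # Documented behaviour: add the decay term to state['loss'].
    state['loss'] = state['loss'] + weight_decay_term(params, rate, p)


state = {'loss': 1.0}
apply_weight_decay(state, params=[[3.0, 4.0]], rate=0.1, p=2)
print(state['loss'])  # 1.5 (1.0 + 0.1 * 5.0)
```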

Decorators

torchbearer.callbacks.decorators.add_to_loss(func)[source]

The add_to_loss() decorator is used to initialise a Callback that adds the value returned from func to the loss.

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback which adds the returned value from func to the loss
Return type:Callback
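add_to_loss turns a plain function of the state into a callback. Below is a minimal sketch of the pattern with a hypothetical stand-in Callback. It assumes the value is added to state['loss'] in on_criterion, the same hook WeightDecay uses, which may not match torchbearer exactly:

```python
class Callback:
    """Hypothetical stand-in base class."""
    def on_criterion(self, state):
        pass


def add_to_loss(func):
    """Sketch: build a callback instance that adds func(state) to state['loss']."""
    class _AddToLoss(Callback):
        def on_criterion(self, state):
            state['loss'] = state['loss'] + func(state)
    return _AddToLoss()


@add_to_loss
def output_penalty(state):
    # Example penalty on a scalar prediction (illustrative only).
    return 0.5 * state['y_pred'] ** 2


state = {'loss': 1.0, 'y_pred': 2.0}
output_penalty.on_criterion(state)
print(state['loss'])  # 3.0
```

After decoration, the name refers to a callback instance rather than a function, so it can be passed directly wherever callbacks are expected.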
torchbearer.callbacks.decorators.on_backward(func)[source]

The on_backward() decorator is used to initialise a Callback with on_backward() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_backward() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_criterion(func)[source]

The on_criterion() decorator is used to initialise a Callback with on_criterion() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_criterion() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end(func)[source]

The on_end() decorator is used to initialise a Callback with on_end() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end_epoch(func)[source]

The on_end_epoch() decorator is used to initialise a Callback with on_end_epoch() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end_epoch() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end_training(func)[source]

The on_end_training() decorator is used to initialise a Callback with on_end_training() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end_training() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end_validation(func)[source]

The on_end_validation() decorator is used to initialise a Callback with on_end_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_forward(func)[source]

The on_forward() decorator is used to initialise a Callback with on_forward() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_forward() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_forward_validation(func)[source]

The on_forward_validation() decorator is used to initialise a Callback with on_forward_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_forward_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_sample(func)[source]

The on_sample() decorator is used to initialise a Callback with on_sample() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_sample() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_sample_validation(func)[source]

The on_sample_validation() decorator is used to initialise a Callback with on_sample_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_sample_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start(func)[source]

The on_start() decorator is used to initialise a Callback with on_start() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_start() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start_epoch(func)[source]

The on_start_epoch() decorator is used to initialise a Callback with on_start_epoch() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_start_epoch() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start_training(func)[source]

The on_start_training() decorator is used to initialise a Callback with on_start_training() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_start_training() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start_validation(func)[source]

The on_start_validation() decorator is used to initialise a Callback with on_start_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_start_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_step_training(func)[source]

The on_step_training() decorator is used to initialise a Callback with on_step_training() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_step_training() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_step_validation(func)[source]

The on_step_validation() decorator is used to initialise a Callback with on_step_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_step_validation() calling func
Return type:Callback
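Each on_* decorator follows the same pattern. It wraps func(state) in a Callback whose corresponding hook calls func and leaves all other hooks as no-ops. The sketch below shows one way to build such decorators from a single hypothetical factory, _bind_to. It is an illustration, not torchbearer's implementation:

```python
class Callback:
    """Hypothetical stand-in base class with no-op hooks."""
    def on_start(self, state): pass
    def on_end_epoch(self, state): pass


def _bind_to(hook_name):
    # Returns a decorator that turns func(state) into a Callback instance
    # whose `hook_name` hook calls func.
    def decorator(func):
        callback = Callback()
        setattr(callback, hook_name, func)  # instance attribute: func receives only state
        return callback
    return decorator


on_start = _bind_to('on_start')
on_end_epoch = _bind_to('on_end_epoch')


@on_end_epoch
def record_epoch(state):
    state.setdefault('seen', []).append(state['epoch'])


state = {}
for epoch in range(3):
    state['epoch'] = epoch
    record_epoch.on_start(state)      # inherited no-op
    record_epoch.on_end_epoch(state)  # calls the decorated function
print(state['seen'])  # [0, 1, 2]
```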