torchbearer.callbacks
-
class torchbearer.callbacks.callbacks.Callback
Base callback class.
Note
All callbacks should inherit from this class.
-
on_backward(state)
Perform some action with the given state as context after backward has been called on the loss.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_criterion(state)
Perform some action with the given state as context after the criterion has been evaluated.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end(state)
Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_training(state)
Perform some action with the given state as context after the training loop has completed.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_validation(state)
Perform some action with the given state as context at the end of the validation loop.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_forward(state)
Perform some action with the given state as context after the forward pass (model output) has been completed.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_forward_validation(state)
Perform some action with the given state as context after the forward pass (model output) has been completed with the validation data.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_sample(state)
Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_sample_validation(state)
Perform some action with the given state as context after data has been sampled from the validation generator.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start(state)
Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_epoch(state)
Perform some action with the given state as context at the start of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_training(state)
Perform some action with the given state as context at the start of the training loop.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_validation(state)
Perform some action with the given state as context at the start of the validation loop.
Parameters: state (dict[str,any]) – The current state dict of the Model.
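To make the hook contract concrete, here is a minimal, standalone Python sketch of a callback base class whose hooks are all no-ops, plus a subclass that overrides one of them. It does not import torchbearer; the EpochCounter class and the 'epochs_seen' state key are hypothetical, and the loop at the end only imitates the order in which a trainer might call the hooks.

```python
class Callback:
    """Standalone stand-in for a callback base class: every hook is a no-op."""

    def on_start(self, state): pass
    def on_start_epoch(self, state): pass
    def on_start_training(self, state): pass
    def on_sample(self, state): pass
    def on_forward(self, state): pass
    def on_criterion(self, state): pass
    def on_backward(self, state): pass
    def on_end_training(self, state): pass
    def on_start_validation(self, state): pass
    def on_sample_validation(self, state): pass
    def on_forward_validation(self, state): pass
    def on_end_validation(self, state): pass
    def on_end_epoch(self, state): pass
    def on_end(self, state): pass


class EpochCounter(Callback):
    """Hypothetical callback that overrides a single hook."""

    def __init__(self):
        self.epochs_seen = 0

    def on_end_epoch(self, state):
        self.epochs_seen += 1
        state['epochs_seen'] = self.epochs_seen  # hypothetical state key


# Imitate the order in which a fit loop might call the hooks.
state = {}
counter = EpochCounter()
counter.on_start(state)
for epoch in range(3):
    counter.on_start_epoch(state)
    counter.on_end_epoch(state)
counter.on_end(state)
print(state['epochs_seen'])  # 3
```

Because every hook has a default no-op body, a subclass only needs to implement the hooks it cares about.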
-
class torchbearer.callbacks.callbacks.CallbackList(callback_list)
The CallbackList class is a wrapper for a list of callbacks which acts as a single callback.
-
on_backward(state)
Call on_backward on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_criterion(state)
Call on_criterion on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end(state)
Call on_end on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_epoch(state)
Call on_end_epoch on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_training(state)
Call on_end_training on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_validation(state)
Call on_end_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_forward(state)
Call on_forward on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_forward_validation(state)
Call on_forward_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_sample(state)
Call on_sample on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_sample_validation(state)
Call on_sample_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start(state)
Call on_start on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_epoch(state)
Call on_start_epoch on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_training(state)
Call on_start_training on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_validation(state)
Call on_start_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Model.
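The dispatch behaviour described for CallbackList (each hook forwarded to every wrapped callback, in list order) can be sketched without torchbearer as follows. This CallbackList is a simplified standalone stand-in, not the library's class, and the Recorder callback and 'epoch' state key exist only for the demonstration.

```python
HOOKS = (
    'on_start', 'on_start_epoch', 'on_start_training', 'on_sample',
    'on_forward', 'on_criterion', 'on_backward', 'on_end_training',
    'on_start_validation', 'on_sample_validation', 'on_forward_validation',
    'on_end_validation', 'on_end_epoch', 'on_end',
)


class CallbackList:
    """Simplified stand-in: wraps several callbacks so they act as one."""

    def __init__(self, callback_list):
        self.callback_list = list(callback_list)


def _make_hook(name):
    # Build a method that forwards `name` to each wrapped callback in turn.
    def hook(self, state):
        for callback in self.callback_list:
            getattr(callback, name)(state)
    hook.__name__ = name
    return hook


for _name in HOOKS:
    setattr(CallbackList, _name, _make_hook(_name))


class Recorder:
    """Toy callback that records which hooks fired."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_start_epoch(self, state):
        self.log.append((self.name, 'on_start_epoch', state['epoch']))

    def on_end_epoch(self, state):
        self.log.append((self.name, 'on_end_epoch', state['epoch']))


log = []
callbacks = CallbackList([Recorder('a', log), Recorder('b', log)])
state = {'epoch': 0}
callbacks.on_start_epoch(state)
callbacks.on_end_epoch(state)
print(log)
# [('a', 'on_start_epoch', 0), ('b', 'on_start_epoch', 0), ('a', 'on_end_epoch', 0), ('b', 'on_end_epoch', 0)]
```

Generating the forwarding methods from one list of hook names keeps the wrapper in step with the base class if hooks are added later.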
Model Checkpointers
-
class torchbearer.callbacks.checkpointers.Best(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='auto', period=1, min_delta=0, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)
Model checkpointer which saves the best model according to a metric.
-
class torchbearer.callbacks.checkpointers.Interval(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', period=1, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)
Model checkpointer which saves the model every given number of epochs.
-
torchbearer.callbacks.checkpointers.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', save_best_only=False, mode='auto', period=1, min_delta=0)
Save the model every period epochs (after every epoch by default). filepath can contain named formatting options, which will be filled with values from state. For example, if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename. The torch model will be saved to filename.pt and the torchbearer model state will be saved to filename.torchbearer.
Parameters:
- filepath (str) – Path to save the model file
- monitor (str) – Quantity to monitor
- save_best_only (bool) – If save_best_only=True, the latest best model according to the quantity monitored will not be overwritten
- mode (str) – One of {auto, min, max}. If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc this should be max, for val_loss this should be min, and so on. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
- period (int) – Interval (number of epochs) between checkpoints
- min_delta (float) – If save_best_only=True, this is the minimum improvement required to trigger a save
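The filename templating and the save_best_only/min_delta decision described above can be sketched in plain Python. This is an illustration under assumptions, not torchbearer's code: it assumes a flat state dict in which the epoch and metric values sit at the top level, it assumes the Keras-style convention that mode='auto' picks 'max' when the monitored name contains 'acc' and 'min' otherwise, and the helper names checkpoint_path, infer_mode and BestTracker are hypothetical.

```python
import math


def checkpoint_path(filepath, state):
    # Named fields such as {epoch:02d} and {val_loss:.2f} are filled from state
    # (assumes a flat dict of values).
    return filepath.format(**state)


def infer_mode(monitor, mode='auto'):
    # Assumption: 'auto' means 'max' for accuracy-like names, otherwise 'min'.
    if mode != 'auto':
        return mode
    return 'max' if 'acc' in monitor else 'min'


class BestTracker:
    """Hypothetical helper: does a new value beat the best by more than min_delta?"""

    def __init__(self, monitor='val_loss', mode='auto', min_delta=0):
        self.mode = infer_mode(monitor, mode)
        self.min_delta = abs(min_delta)
        self.best = math.inf if self.mode == 'min' else -math.inf

    def improved(self, value):
        if self.mode == 'min':
            better = value < self.best - self.min_delta
        else:
            better = value > self.best + self.min_delta
        if better:
            self.best = value  # a save would be triggered here
        return better


state = {'epoch': 3, 'val_loss': 0.2512}
print(checkpoint_path('model.{epoch:02d}-{val_loss:.2f}.pt', state))  # model.03-0.25.pt

tracker = BestTracker(monitor='val_loss', min_delta=0.01)
decisions = [tracker.improved(v) for v in [0.50, 0.495, 0.40, 0.41]]
print(decisions)  # [True, False, True, False]
```

With min_delta=0.01, the drop from 0.50 to 0.495 is too small to count as an improvement, so no new checkpoint would be written for that epoch.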
-
class torchbearer.callbacks.checkpointers.MostRecent(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)
Model checkpointer which saves the most recent model.
Logging
-
class torchbearer.callbacks.csv_logger.CSVLogger(filename, separator=', ', batch_granularity=False, write_header=True, append=False)
Callback to log metrics to a CSV file.
-
on_end(state)
Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict[str,any]) – The current state dict of the Model.
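To show what a per-epoch CSV log looks like, here is a small standalone sketch that writes a header row and then one row of metric values per call, joined with a separator. It writes to an io.StringIO buffer instead of a file and is not torchbearer's implementation; SimpleCSVLogger is a hypothetical name. The separator is joined by hand because the standard csv module only accepts a one-character delimiter, while the documented default here is ', '.

```python
import io


class SimpleCSVLogger:
    """Hypothetical CSV metric logger writing to any text stream."""

    def __init__(self, stream, separator=', ', write_header=True):
        self.stream = stream
        self.separator = separator
        self.write_header = write_header
        self.fieldnames = None

    def log(self, metrics):
        # Fix the column order from the first row seen.
        if self.fieldnames is None:
            self.fieldnames = list(metrics)
            if self.write_header:
                self.stream.write(self.separator.join(self.fieldnames) + '\n')
        row = [str(metrics[name]) for name in self.fieldnames]
        self.stream.write(self.separator.join(row) + '\n')


buffer = io.StringIO()
logger = SimpleCSVLogger(buffer)
logger.log({'epoch': 0, 'loss': 0.9})
logger.log({'epoch': 1, 'loss': 0.7})
print(buffer.getvalue())
# epoch, loss
# 0, 0.9
# 1, 0.7
```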
-
class torchbearer.callbacks.printer.ConsolePrinter(validation_label_letter='v')
The ConsolePrinter callback simply outputs the training metrics to the console.
-
on_end_training(state)
Perform some action with the given state as context after the training loop has completed.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_validation(state)
Perform some action with the given state as context at the end of the validation loop.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
class torchbearer.callbacks.printer.Tqdm(validation_label_letter='v')
The Tqdm callback outputs the progress and metrics for training and validation loops to the console using tqdm.
-
on_end_training(state)
Update the bar with the final training metrics and then close it.
Parameters: state (dict) – The Model state
-
on_end_validation(state)
Update the bar with the final validation metrics and then close it.
Parameters: state (dict) – The Model state
-
on_start_training(state)
Initialise the tqdm bar for this training phase.
Parameters: state (dict) – The Model state
-
on_start_validation(state)
Initialise the tqdm bar for this validation phase.
Parameters: state (dict) – The Model state
Tensorboard
-
class torchbearer.callbacks.tensor_board.TensorBoard(log_dir='./logs', write_graph=True, write_batch_metrics=False, batch_step_size=10, write_epoch_metrics=True, comment='torchbearer')
The TensorBoard callback is used to write metric graphs to TensorBoard. Requires the tensorboardX library for Python.
-
on_end(state)
Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_sample(state)
Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start(state)
Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start_epoch(state)
Perform some action with the given state as context at the start of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
class torchbearer.callbacks.tensor_board.TensorBoardImages(log_dir='./logs', comment='torchbearer', name='Image', key='y_pred', write_each_epoch=True, num_images=16, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
The TensorBoardImages callback will write a selection of images from the validation pass to TensorBoard using the tensorboardX library and torchvision.utils.make_grid.
-
on_end(state)
Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
class torchbearer.callbacks.tensor_board.TensorBoardProjector(log_dir='./logs', comment='torchbearer', num_images=100, avg_pool_size=1, avg_data_channels=True, write_data=True, write_features=True, features_key='y_pred')
The TensorBoardProjector callback is used to write images from the validation pass to TensorBoard using the tensorboardX library.
-
on_end(state)
Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
Early Stopping
-
class torchbearer.callbacks.early_stopping.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto')
Callback to stop training when a monitored quantity has stopped improving.
-
on_end(state)
Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict[str,any]) – The current state dict of the Model.
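The following standalone sketch shows one common patience rule for early stopping, in the style popularised by Keras: a value counts as an improvement only if it beats the best seen so far by more than min_delta, and training stops once more than patience consecutive epochs have failed to improve. The exact counting rule used by torchbearer is an assumption here, and EarlyStoppingSketch is a hypothetical name.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical patience-based stopping rule (Keras-style)."""

    def __init__(self, min_delta=0, patience=0, mode='min'):
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = math.inf if mode == 'min' else -math.inf
        self.wait = 0  # consecutive epochs without improvement

    def should_stop(self, value):
        if self.mode == 'min':
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        # Stop once more than `patience` epochs in a row have not improved.
        return self.wait > self.patience


stopper = EarlyStoppingSketch(patience=1)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 3
```

With patience=1, the first non-improving epoch (0.85) is tolerated and the second (0.9) triggers the stop, so the later improvement to 0.7 is never seen.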
-
class torchbearer.callbacks.terminate_on_nan.TerminateOnNaN(monitor='running_loss')
Callback that terminates training when the given metric is NaN or inf.
-
on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
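The check behind TerminateOnNaN amounts to testing the monitored value for NaN or infinity, as in this sketch. It assumes a flat state dict holding the metric under its name; should_terminate is a hypothetical helper, not part of torchbearer.

```python
import math


def should_terminate(state, monitor='running_loss'):
    """Hypothetical check: True if the monitored value is NaN or +/-inf."""
    value = state.get(monitor)
    if value is None:
        return False  # nothing to check yet
    return math.isnan(value) or math.isinf(value)


print(should_terminate({'running_loss': 0.3}))           # False
print(should_terminate({'running_loss': float('nan')}))  # True
print(should_terminate({'running_loss': float('inf')}))  # True
```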
Gradient Clipping
-
class torchbearer.callbacks.gradient_clipping.GradientClipping(clip_value, params=None)
GradientClipping callback, which uses torch.nn.utils.clip_grad_value_ to clamp each gradient entry to the range [-clip_value, clip_value].
-
class torchbearer.callbacks.gradient_clipping.GradientNormClipping(max_norm, norm_type=2, params=None)
GradientNormClipping callback, which uses torch.nn.utils.clip_grad_norm_ to rescale the gradients so that their total norm (of type norm_type) does not exceed max_norm.
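The two PyTorch utilities named above behave differently: clip_grad_value_ clamps each gradient entry independently, while clip_grad_norm_ rescales all gradients together when their combined norm exceeds max_norm. The plain-Python sketch below mirrors that arithmetic on a flat list of gradient values (PyTorch works on tensors across all parameters); the helper names clip_by_value and clip_by_norm are hypothetical, and the 1e-6 term follows the small epsilon PyTorch adds to the norm.

```python
import math


def clip_by_value(grads, clip_value):
    # Clamp each entry independently to [-clip_value, clip_value].
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads, max_norm, norm_type=2.0):
    # Compute the total norm over all entries.
    if norm_type == math.inf:
        total_norm = max(abs(g) for g in grads)
    else:
        total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    # Rescale every entry by the same factor, only if the norm is too large.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


print(clip_by_value([-3.0, 0.5, 2.0], clip_value=1.0))  # [-1.0, 0.5, 1.0]
clipped, norm = clip_by_norm([3.0, 4.0], max_norm=1.0)
print(norm)  # 5.0
```

Note that norm clipping preserves the direction of the gradient vector, whereas value clipping can change it.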
Learning Rate Schedulers
-
class torchbearer.callbacks.torch_scheduler.CosineAnnealingLR(T_max, eta_min=0, last_epoch=-1, step_on_batch=False)
See: PyTorch CosineAnnealingLR
-
class torchbearer.callbacks.torch_scheduler.ExponentialLR(gamma, last_epoch=-1, step_on_batch=False)
See: PyTorch ExponentialLR
-
class torchbearer.callbacks.torch_scheduler.LambdaLR(lr_lambda, last_epoch=-1, step_on_batch=False)
See: PyTorch LambdaLR
-
class torchbearer.callbacks.torch_scheduler.MultiStepLR(milestones, gamma=0.1, last_epoch=-1, step_on_batch=False)
See: PyTorch MultiStepLR
-
class torchbearer.callbacks.torch_scheduler.ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, step_on_batch=False)
Parameters: monitor (str) – The quantity to monitor. (Default value = 'val_loss')
See: PyTorch ReduceLROnPlateau
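For intuition about the parameters above, the following standalone sketch restates the plateau logic documented for PyTorch's ReduceLROnPlateau, restricted to mode='min' and threshold_mode='rel' for a single learning rate. It does not use torch; PlateauSketch is a hypothetical name, and verbose output is omitted.

```python
import math


class PlateauSketch:
    """Hypothetical single-LR restatement of plateau-based LR reduction."""

    def __init__(self, lr, factor=0.1, patience=10, threshold=1e-4,
                 cooldown=0, min_lr=0.0, eps=1e-8):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.eps = eps
        self.best = math.inf
        self.num_bad_epochs = 0
        self.cooldown_counter = 0

    def step(self, metric):
        # Relative threshold: the metric must beat best by a fraction `threshold`.
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.cooldown_counter > 0:
            # Ignore bad epochs while cooling down after a reduction.
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0

        if self.num_bad_epochs > self.patience:
            new_lr = max(self.lr * self.factor, self.min_lr)
            if self.lr - new_lr > self.eps:  # skip negligible changes
                self.lr = new_lr
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
        return self.lr


scheduler = PlateauSketch(lr=0.1, patience=1)
lrs = [scheduler.step(loss) for loss in [1.0, 1.0, 1.0, 0.5]]
print(lrs)  # the LR is multiplied by `factor` on the third call
```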
-
class torchbearer.callbacks.torch_scheduler.StepLR(step_size, gamma=0.1, last_epoch=-1, step_on_batch=False)
See: PyTorch StepLR
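The PyTorch schedulers referenced above follow simple closed forms in the number of scheduler steps taken so far, called epoch here. With step_on_batch=True one would expect that count to be batches rather than epochs; that reading is an assumption. This plain-Python sketch evaluates the closed forms directly; the function names are hypothetical, and the cosine form is the closed-form expression given in the PyTorch documentation rather than PyTorch's step-by-step implementation.

```python
import math
from bisect import bisect_right


def step_lr(base_lr, epoch, step_size, gamma=0.1):
    # Decay by gamma every step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    # Decay by gamma once for every milestone already reached.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


def exponential_lr(base_lr, epoch, gamma):
    # Decay by gamma every epoch.
    return base_lr * gamma ** epoch


def cosine_annealing_lr(base_lr, epoch, T_max, eta_min=0):
    # Half a cosine period from base_lr down to eta_min over T_max epochs.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2


def lambda_lr(base_lr, epoch, lr_lambda):
    # Scale the initial LR by a user-supplied function of the epoch.
    return base_lr * lr_lambda(epoch)


print([round(step_lr(0.1, e, step_size=2), 6) for e in range(5)])
print([round(cosine_annealing_lr(0.1, e, T_max=4), 4) for e in range(5)])
```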
-
class torchbearer.callbacks.torch_scheduler.TorchScheduler(scheduler_builder, monitor=None, step_on_batch=False)
-
on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_sample(state)
Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict[str,any]) – The current state dict of the Model.
-
on_start(state)
Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict[str,any]) – The current state dict of the Model.
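TorchScheduler is documented above only through its constructor and hooks. One plausible reading, sketched below as an assumption rather than the library's code, is that on_start builds the scheduler from the optimiser, and that step() is then called either in on_sample (when step_on_batch=True) or in on_end_epoch, passing the monitored value when monitor is set. SchedulerCallbackSketch, CountingScheduler and the flat 'optimizer'/'val_loss' state keys are hypothetical.

```python
class SchedulerCallbackSketch:
    """Hypothetical wrapper showing where a scheduler might be stepped."""

    def __init__(self, scheduler_builder, monitor=None, step_on_batch=False):
        self._builder = scheduler_builder
        self._monitor = monitor
        self._step_on_batch = step_on_batch
        self.scheduler = None

    def on_start(self, state):
        # Build the scheduler once the optimiser is available.
        self.scheduler = self._builder(state['optimizer'])

    def _step(self, state):
        # Pass the monitored metric only when one was requested.
        if self._monitor is None:
            self.scheduler.step()
        else:
            self.scheduler.step(state[self._monitor])

    def on_sample(self, state):
        if self._step_on_batch:
            self._step(state)

    def on_end_epoch(self, state):
        if not self._step_on_batch:
            self._step(state)


class CountingScheduler:
    """Fake scheduler that records each step() call."""

    def __init__(self, optimizer):
        self.calls = []

    def step(self, metric=None):
        self.calls.append(metric)


cb = SchedulerCallbackSketch(CountingScheduler, monitor='val_loss')
state = {'optimizer': object(), 'val_loss': 0.4}
cb.on_start(state)
for _ in range(3):
    cb.on_sample(state)  # ignored: this instance steps once per epoch
cb.on_end_epoch(state)
print(cb.scheduler.calls)  # [0.4]
```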
Weight Decay
-
class torchbearer.callbacks.weight_decay.L1WeightDecay(rate=0.0005, params=None)
WeightDecay callback which uses an L1 norm.
-
class torchbearer.callbacks.weight_decay.L2WeightDecay(rate=0.0005, params=None)
WeightDecay callback which uses an L2 norm.
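As a rough illustration of what an L1 or L2 weight-decay term contributes, the sketch below computes rate times the sum of each parameter group's p-norm over plain Python lists. It is an assumption that the callbacks add exactly this quantity to the loss (rather than, for example, the squared L2 norm), and the function name weight_decay_penalty is hypothetical.

```python
def weight_decay_penalty(params, rate=5e-4, p=2):
    """Hypothetical penalty: rate * sum of each parameter group's p-norm."""
    total = 0.0
    for weights in params:
        # p-norm of one parameter group: (sum |w|^p)^(1/p)
        total += sum(abs(w) ** p for w in weights) ** (1.0 / p)
    return rate * total


params = [[3.0, 4.0], [1.0, -1.0]]
l1 = weight_decay_penalty(params, rate=0.1, p=1)  # 0.1 * (7 + 2)
l2 = weight_decay_penalty(params, rate=0.1, p=2)  # 0.1 * (5 + sqrt(2))
print(l1, l2)
```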
Decorators
-
torchbearer.callbacks.decorators.add_to_loss(func)
The add_to_loss() decorator is used to initialise a Callback with the value returned from func being added to the loss.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback which adds the returned value from func to the loss
Return type: Callback
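A minimal sketch of the add_to_loss() pattern, written without torchbearer: the decorator turns a function of state into an object whose hook adds the function's return value to the loss. Which hook the real decorator uses is not stated here, so on_criterion (the hook that runs after the criterion is evaluated) is an assumption, as are the 'loss' and 'weights' state keys and the AddToLossCallback name.

```python
class AddToLossCallback:
    """Hypothetical callback that adds func(state) to the loss."""

    def __init__(self, func):
        self.func = func

    def on_criterion(self, state):
        # Assumed hook: runs after the criterion has produced state['loss'].
        state['loss'] = state['loss'] + self.func(state)


def add_to_loss(func):
    return AddToLossCallback(func)


@add_to_loss
def l2_penalty(state):
    return 0.01 * sum(w * w for w in state['weights'])


state = {'loss': 1.0, 'weights': [1.0, 2.0]}
l2_penalty.on_criterion(state)
print(state['loss'])  # 1.0 plus 0.01 * (1 + 4)
```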
-
torchbearer.callbacks.decorators.on_backward(func)
The on_backward() decorator is used to initialise a Callback with on_backward() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_backward() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_criterion(func)
The on_criterion() decorator is used to initialise a Callback with on_criterion() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_criterion() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_end(func)
The on_end() decorator is used to initialise a Callback with on_end() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_end() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_end_epoch(func)
The on_end_epoch() decorator is used to initialise a Callback with on_end_epoch() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_end_epoch() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_end_training(func)
The on_end_training() decorator is used to initialise a Callback with on_end_training() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_end_training() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_end_validation(func)
The on_end_validation() decorator is used to initialise a Callback with on_end_validation() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_end_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_forward(func)
The on_forward() decorator is used to initialise a Callback with on_forward() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_forward() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_forward_validation(func)
The on_forward_validation() decorator is used to initialise a Callback with on_forward_validation() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_forward_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_sample(func)
The on_sample() decorator is used to initialise a Callback with on_sample() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_sample() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_sample_validation(func)
The on_sample_validation() decorator is used to initialise a Callback with on_sample_validation() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_sample_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_start(func)
The on_start() decorator is used to initialise a Callback with on_start() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_start() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_start_epoch(func)
The on_start_epoch() decorator is used to initialise a Callback with on_start_epoch() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_start_epoch() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_start_training(func)
The on_start_training() decorator is used to initialise a Callback with on_start_training() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_start_training() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_start_validation(func)
The on_start_validation() decorator is used to initialise a Callback with on_start_validation() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_start_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_step_training(func)
The on_step_training() decorator is used to initialise a Callback with on_step_training() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_step_training() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.on_step_validation(func)
The on_step_validation() decorator is used to initialise a Callback with on_step_validation() calling the decorated function.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with Callback.on_step_validation() calling func
Return type: Callback
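The hook decorators listed above all follow one shape: wrap a function of state in a Callback so that the named hook calls it. The standalone sketch below reproduces that shape for two hooks; the base class, the _hook_decorator factory and the 'reports' state key are hypothetical stand-ins, not torchbearer's implementation.

```python
class Callback:
    """Stand-in base class with no-op hooks (only two shown here)."""

    def on_start(self, state):
        pass

    def on_end_epoch(self, state):
        pass


def _hook_decorator(hook_name):
    # Factory: returns a decorator that installs func as the named hook.
    def decorator(func):
        callback = Callback()
        # An instance attribute shadows the class's no-op method.
        setattr(callback, hook_name, func)
        return callback
    return decorator


on_start = _hook_decorator('on_start')
on_end_epoch = _hook_decorator('on_end_epoch')


@on_end_epoch
def report(state):
    state.setdefault('reports', []).append(state['epoch'])


state = {'epoch': 2}
report.on_start(state)      # untouched hook: still a no-op
report.on_end_epoch(state)  # calls the decorated function
print(state['reports'])  # [2]
```

After decoration, the name report refers to a Callback instance rather than the original function, so it can be passed wherever a callback is expected.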