torchbearer.callbacks

class torchbearer.callbacks.callbacks.Callback[source]

Base callback class.

Note

All custom callbacks should derive from this class.
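
For illustration, a minimal custom callback might record the running metrics after each training step. This is a sketch only; the string state key 'metrics' is an assumption here:

    from torchbearer.callbacks import Callback

    class MetricHistory(Callback):
        """Record the running metrics reported after each training step."""

        def __init__(self):
            super(MetricHistory, self).__init__()
            self.history = []

        def on_step_training(self, state):
            # 'metrics' is assumed to be the key holding the running metric dict
            self.history.append(dict(state['metrics']))

        def state_dict(self):
            # Persist the recorded history so the callback can be resumed
            return {'history': self.history}

        def load_state_dict(self, state_dict):
            self.history = state_dict['history']
            return self
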

load_state_dict(state_dict)[source]

Resume this callback from the given state. Expects that this callback was constructed in the same way.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:Callback
on_backward(state)[source]

Perform some action with the given state as context after backward has been called on the loss.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_checkpoint(state)[source]

Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion(state)[source]

Perform some action with the given state as context after the criterion has been evaluated.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion_validation(state)[source]

Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_training(state)[source]

Perform some action with the given state as context after the training loop has completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)[source]

Perform some action with the given state as context at the end of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward(state)[source]

Perform some action with the given state as context after the forward pass (model output) has been completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward_validation(state)[source]

Perform some action with the given state as context after the forward pass (model output) has been completed with the validation data.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample_validation(state)[source]

Perform some action with the given state as context after data has been sampled from the validation generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Perform some action with the given state as context at the start of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Perform some action with the given state as context at the start of the training loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_validation(state)[source]

Perform some action with the given state as context at the start of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
state_dict()[source]

Get a dict containing the callback state.

Returns:A dict containing parameters and persistent buffers.
Return type:dict
class torchbearer.callbacks.callbacks.CallbackList(callback_list)[source]

The CallbackList class is a wrapper for a list of callbacks which acts as a single Callback and internally calls each Callback in the given list in turn.

Parameters:callback_list (list) – The list of callbacks to be wrapped. If the list contains a CallbackList, this will be unwrapped.
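
A CallbackList is normally built internally from the list of callbacks passed to a Trial, but the sketch below (using a hypothetical Greeter callback and an empty state dict) shows the delegation behaviour directly:

    from torchbearer.callbacks import Callback, CallbackList

    class Greeter(Callback):
        def __init__(self, name):
            super(Greeter, self).__init__()
            self.name = name

        def on_start(self, state):
            print(self.name, 'starting')

    callbacks = CallbackList([Greeter('first'), Greeter('second')])
    callbacks.on_start({})  # calls on_start on each wrapped callback in turn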

CALLBACK_STATES = 'callback_states'
CALLBACK_TYPES = 'callback_types'
append(callback_list)[source]
copy()[source]
load_state_dict(state_dict)[source]

Resume this callback list from the given state. Callbacks must be given in the same order for this to work.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:CallbackList
on_backward(state)[source]

Call on_backward on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_checkpoint(state)[source]

Call on_checkpoint on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion(state)[source]

Call on_criterion on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion_validation(state)[source]

Call on_criterion_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end(state)[source]

Call on_end on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Call on_end_epoch on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_training(state)[source]

Call on_end_training on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)[source]

Call on_end_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward(state)[source]

Call on_forward on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward_validation(state)[source]

Call on_forward_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Call on_sample on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample_validation(state)[source]

Call on_sample_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Call on_start on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Call on_start_epoch on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Call on_start_training on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_validation(state)[source]

Call on_start_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Call on_step_training on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Call on_step_validation on each callback in turn with the given state.

Parameters:state (dict[str,any]) – The current state dict of the Model.
state_dict()[source]

Get a dict containing all of the callback states.

Returns:A dict containing parameters and persistent buffers.
Return type:dict

Model Checkpointers

class torchbearer.callbacks.checkpointers.Best(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='auto', period=1, min_delta=0, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)[source]

Model checkpointer which saves the best model according to the given configurations.

Parameters:
  • filepath (str) – Path to save the model file
  • monitor (str) – Quantity to monitor
  • mode (str) – One of {auto, min, max}. The decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
  • period (int) – Interval (number of epochs) between checkpoints
  • min_delta (float) – This is the minimum improvement required to trigger a save
  • pickle_module – The pickle module to use, default is ‘torch.serialization.pickle’
  • pickle_protocol – The pickle protocol to use, default is ‘torch.serialization.DEFAULT_PROTOCOL’
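
As a usage sketch, assuming the Trial API (Trial(...).with_generators(...).run(...)); train_loader and val_loader are placeholder data loaders:

    import torch.nn as nn
    import torch.optim as optim
    from torchbearer import Trial
    from torchbearer.callbacks.checkpointers import Best

    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Save a checkpoint whenever val_acc improves (mode='max' since higher is better)
    checkpoint = Best('best.{epoch:02d}-{val_acc:.2f}.pt', monitor='val_acc', mode='max')

    trial = Trial(model, optimizer, nn.CrossEntropyLoss(), metrics=['acc', 'loss'],
                  callbacks=[checkpoint])
    trial.with_generators(train_loader, val_loader)  # train_loader/val_loader assumed defined
    trial.run(epochs=10)
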
load_state_dict(state_dict)[source]

Resume this callback from the given state. Expects that this callback was constructed in the same way.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:Callback
on_checkpoint(state)[source]

Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
state_dict()[source]

Get a dict containing the callback state.

Returns:A dict containing parameters and persistent buffers.
Return type:dict
class torchbearer.callbacks.checkpointers.Interval(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', period=1, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)[source]

Model checkpointer which saves the model every ‘period’ epochs to the given filepath.

Parameters:
  • filepath (str) – Path to save the model file
  • period (int) – Interval (number of epochs) between checkpoints
  • pickle_module – The pickle module to use, default is ‘torch.serialization.pickle’
  • pickle_protocol – The pickle protocol to use, default is ‘torch.serialization.DEFAULT_PROTOCOL’
load_state_dict(state_dict)[source]

Resume this callback from the given state. Expects that this callback was constructed in the same way.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:Callback
on_checkpoint(state)[source]

Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.

Parameters:state (dict[str,any]) – The current state dict of the Model.
state_dict()[source]

Get a dict containing the callback state.

Returns:A dict containing parameters and persistent buffers.
Return type:dict
torchbearer.callbacks.checkpointers.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', save_best_only=False, mode='auto', period=1, min_delta=0)[source]

Save the model after every epoch. filepath can contain named formatting options, which will be filled with values from state. For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename. The torch model will be saved to filename.pt and the torchbearer state will be saved to filename.torchbearer.

Parameters:
  • filepath (str) – Path to save the model file
  • monitor (str) – Quantity to monitor
  • save_best_only (bool) – If save_best_only=True, the latest best model according to the quantity monitored will not be overwritten
  • mode (str) – One of {auto, min, max}. If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
  • period (int) – Interval (number of epochs) between checkpoints
  • min_delta (float) – If save_best_only=True, this is the minimum improvement required to trigger a save
class torchbearer.callbacks.checkpointers.MostRecent(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)[source]

Model checkpointer which saves the most recent model to a given filepath.

Parameters:
  • filepath (str) – Path to save the model file
  • pickle_module – The pickle module to use, default is ‘torch.serialization.pickle’
  • pickle_protocol – The pickle protocol to use, default is ‘torch.serialization.DEFAULT_PROTOCOL’
on_checkpoint(state)[source]

Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.

Parameters:state (dict[str,any]) – The current state dict of the Model.

Logging

class torchbearer.callbacks.csv_logger.CSVLogger(filename, separator=', ', batch_granularity=False, write_header=True, append=False)[source]

Callback to log metrics to a given csv file.

Parameters:
  • filename (str) – The name of the file to output to
  • separator (str) – The delimiter to use (e.g. comma, tab etc.)
  • batch_granularity (bool) – If True, write on each batch, else on each epoch
  • write_header (bool) – If True, write the CSV header at the beginning of training
  • append (bool) – If True, append to the file instead of replacing it
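
For example, a minimal sketch (the logger is then passed in the callbacks list of a Trial):

    from torchbearer.callbacks.csv_logger import CSVLogger

    # Write one row of metrics per epoch; append=True keeps rows from earlier runs
    logger = CSVLogger('training_log.csv', separator=',', append=True)
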
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.printer.ConsolePrinter(validation_label_letter='v', precision=4)[source]

The ConsolePrinter callback simply outputs the training metrics to the console.

Parameters:
  • validation_label_letter (str) – This is the letter displayed after the epoch number indicating the current phase of training
  • precision (int) – Precision of the number format in significant figures
on_end_training(state)[source]

Perform some action with the given state as context after the training loop has completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)[source]

Perform some action with the given state as context at the end of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.printer.Tqdm(tqdm_module=tqdm.tqdm, validation_label_letter='v', precision=4, on_epoch=False, **tqdm_args)[source]

The Tqdm callback outputs the progress and metrics for training and validation loops to the console using TQDM. The given key is used to label validation output.

Parameters:
  • validation_label_letter (str) – The letter to use for validation outputs.
  • precision (int) – Precision of the number format in significant figures
  • on_epoch (bool) – If True, output a single progress bar which tracks epochs
  • tqdm_args – Any extra keyword args provided here will be passed through to the tqdm module constructor. See github.com/tqdm/tqdm#parameters for more details.
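
For example, a minimal sketch (ncols is a standard tqdm keyword argument forwarded via tqdm_args):

    from torchbearer.callbacks.printer import ConsolePrinter, Tqdm

    # A single progress bar tracking epochs, limited to 100 columns
    progress = Tqdm(on_epoch=True, ncols=100)

    # Or a plain console printer with 3 significant figures
    printer = ConsolePrinter(precision=3)
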
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_training(state)[source]

Update the bar with the terminal training metrics and then close.

Parameters:state (dict) – The Model state
on_end_validation(state)[source]

Update the bar with the terminal validation metrics and then close.

Parameters:state (dict) – The Model state
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Initialise the TQDM bar for this training phase.

Parameters:state (dict) – The Model state
on_start_validation(state)[source]

Initialise the TQDM bar for this validation phase.

Parameters:state (dict) – The Model state
on_step_training(state)[source]

Update the bar with the metrics from this step.

Parameters:state (dict) – The Model state
on_step_validation(state)[source]

Update the bar with the metrics from this step.

Parameters:state (dict) – The Model state

Tensorboard

class torchbearer.callbacks.tensor_board.AbstractTensorBoard(log_dir='./logs', comment='torchbearer', visdom=False, visdom_params=None)[source]

TensorBoard callback which writes metrics to the given log directory. Requires the TensorboardX library for Python.

Parameters:
  • log_dir (str) – The tensorboard log path for output
  • comment (str) – Descriptive comment to append to path
  • visdom (bool) – If true, log to visdom instead of tensorboard
  • visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
close_writer(log_dir=None)[source]

Decrement the reference count for a writer belonging to the given log directory (or the default writer if the directory is not given). If the reference count gets to zero, the writer will be closed and removed.

Parameters:log_dir (str) – The (optional) directory

get_writer(log_dir=None, visdom=False, visdom_params=None)[source]

Get a SummaryWriter for the given directory (or the default writer if the directory is not given). If you are getting a SummaryWriter for a custom directory, it is your responsibility to close it using close_writer.

Parameters:
  • log_dir (str) – The (optional) directory
  • visdom (bool) – If True, return a VisdomWriter; if False, return a tensorboard SummaryWriter
Returns:The SummaryWriter or VisdomWriter

on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.tensor_board.TensorBoard(log_dir='./logs', write_graph=True, write_batch_metrics=False, batch_step_size=10, write_epoch_metrics=True, comment='torchbearer', visdom=False, visdom_params=None)[source]

TensorBoard callback which writes metrics to the given log directory. Requires the TensorboardX library for Python.

Parameters:
  • log_dir (str) – The tensorboard log path for output
  • write_graph (bool) – If True, the model graph will be written using the TensorboardX library
  • write_batch_metrics (bool) – If True, batch metrics will be written
  • batch_step_size (int) – The step size to use when writing batch metrics, make this larger to reduce latency
  • write_epoch_metrics (bool) – If True, metrics from the end of the epoch will be written
  • comment (str) – Descriptive comment to append to path
  • visdom (bool) – If true, log to visdom instead of tensorboard
  • visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
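
For example, a minimal sketch (requires TensorboardX to be installed):

    from torchbearer.callbacks.tensor_board import TensorBoard

    # Log batch metrics every 50 steps and epoch metrics at the end of each epoch
    tboard = TensorBoard(log_dir='./logs', write_graph=False,
                         write_batch_metrics=True, batch_step_size=50,
                         write_epoch_metrics=True, comment='experiment_1')
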
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Perform some action with the given state as context at the start of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.tensor_board.TensorBoardImages(log_dir='./logs', comment='torchbearer', name='Image', key='y_pred', write_each_epoch=True, num_images=16, nrow=8, padding=2, normalize=False, norm_range=None, scale_each=False, pad_value=0, visdom=False, visdom_params=None)[source]

The TensorBoardImages callback will write a selection of images from the validation pass to tensorboard using the TensorboardX library and torchvision.utils.make_grid. Images are selected from the given key and saved to the given path. The full name of the image sub-directory will be model name + '_' + comment.

Parameters:
  • log_dir (str) – The tensorboard log path for output
  • comment (str) – Descriptive comment to append to path
  • name (str) – The name of the image
  • key (str) – The key in state containing image data (tensor of size [c, w, h] or [b, c, w, h])
  • write_each_epoch (bool) – If True, write data on every epoch, else write only for the first epoch.
  • num_images (int) – The number of images to write
  • nrow – See torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid
  • padding – See torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid
  • normalize – See torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid
  • norm_range – See torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid
  • scale_each – See torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid
  • pad_value – See torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid
  • visdom (bool) – If true, log to visdom instead of tensorboard
  • visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
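
For example, a minimal sketch (assumes the model output stored under 'y_pred' is image shaped):

    from torchbearer.callbacks.tensor_board import TensorBoardImages

    # Write a normalised grid of up to 16 predicted images, 8 per row, each epoch
    images = TensorBoardImages(key='y_pred', num_images=16, nrow=8,
                               normalize=True, write_each_epoch=True)
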
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.tensor_board.TensorBoardProjector(log_dir='./logs', comment='torchbearer', num_images=100, avg_pool_size=1, avg_data_channels=True, write_data=True, write_features=True, features_key='y_pred')[source]

The TensorBoardProjector callback is used to write images from the validation pass to Tensorboard using the TensorboardX library. Images are written to the given directory and, if required, so are associated features.

Parameters:
  • log_dir (str) – The tensorboard log path for output
  • comment (str) – Descriptive comment to append to path
  • num_images (int) – The number of images to write
  • avg_pool_size (int) – Size of the average pool to perform on the image. This is recommended to reduce the overall image sizes and improve latency
  • avg_data_channels (bool) – If True, the image data will be averaged in the channel dimension
  • write_data (bool) – If True, the raw data will be written as an embedding
  • write_features (bool) – If True, the image features will be written as an embedding
  • features_key (str) – The key in state to use for the embedding. Typically model output but can be used to show features from any layer of the model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
class torchbearer.callbacks.tensor_board.TensorBoardText(log_dir='./logs', write_epoch_metrics=True, write_batch_metrics=False, log_trial_summary=True, batch_step_size=100, comment='torchbearer', visdom=False, visdom_params=None)[source]

TensorBoard callback which writes metrics as text to the given log directory. Requires the TensorboardX library for Python.

Parameters:
  • log_dir (str) – The tensorboard log path for output
  • write_epoch_metrics (bool) – If True, metrics from the end of the epoch will be written
  • log_trial_summary (bool) – If True, logs a string summary of the Trial
  • batch_step_size (int) – The step size to use when writing batch metrics, make this larger to reduce latency
  • comment (str) – Descriptive comment to append to path
  • visdom (bool) – If true, log to visdom instead of tensorboard
  • visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)[source]

Perform some action with the given state as context at the start of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
static table_formatter(string)[source]
class torchbearer.callbacks.tensor_board.VisdomParams[source]

Class to hold visdom client arguments. Modify member variables before initialising tensorboard callbacks for custom arguments. See the visdom documentation for the meaning of each argument.

ENDPOINT = 'events'
ENV = 'main'
HTTP_PROXY_HOST = None
HTTP_PROXY_PORT = None
IPV6 = True
LOG_TO_FILENAME = None
PORT = 8097
RAISE_EXCEPTIONS = None
SEND = True
SERVER = 'http://localhost'
USE_INCOMING_SOCKET = True
torchbearer.callbacks.tensor_board.close_writer(log_dir, logger)[source]

Decrement the reference count for a writer belonging to a specific log directory. If the reference count gets to zero, the writer will be closed and removed.

Parameters:
  • log_dir – the log directory
  • logger – the object releasing the writer
torchbearer.callbacks.tensor_board.get_writer(log_dir, logger, visdom=False, visdom_params=None)[source]

Get the writer assigned to the given log directory. If the writer doesn’t exist it will be created, and a reference to the logger added.

Parameters:
  • log_dir – the log directory
  • logger – the object requesting the writer. That object should call close_writer when it is finished
  • visdom – if true VisdomWriter is returned instead of tensorboard SummaryWriter
  • visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
Returns:The SummaryWriter or VisdomWriter object

Early Stopping

class torchbearer.callbacks.early_stopping.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto')[source]

Callback to stop training when a monitored quantity has stopped improving.

Parameters:
  • monitor (str) – Quantity to be monitored
  • min_delta (float) – Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.
  • patience (int) – Number of epochs with no improvement after which training will be stopped.
  • verbose (int) – Verbosity mode, will print stopping info if verbose > 0
  • mode (str) – One of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity.
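
For example, a minimal sketch:

    from torchbearer.callbacks.early_stopping import EarlyStopping

    # Stop if val_loss fails to improve by at least 1e-3 for 5 consecutive epochs
    stopper = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5, mode='min')
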
load_state_dict(state_dict)[source]

Resume this callback from the given state. Expects that this callback was constructed in the same way.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:Callback
on_end(state)[source]

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
state_dict()[source]

Get a dict containing the callback state.

Returns:A dict containing parameters and persistent buffers.
Return type:dict
class torchbearer.callbacks.terminate_on_nan.TerminateOnNaN(monitor='running_loss')[source]

Callback which monitors the given metric and halts training if its value is NaN or inf.

Parameters:monitor (str) – The metric name to monitor
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)[source]

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
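
For example, a minimal sketch ('running_loss' is the default monitored metric):

    from torchbearer.callbacks.terminate_on_nan import TerminateOnNaN

    # Halt the fit as soon as the running loss becomes NaN or inf
    nan_guard = TerminateOnNaN(monitor='running_loss')
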

Gradient Clipping

class torchbearer.callbacks.gradient_clipping.GradientClipping(clip_value, params=None)[source]

GradientClipping callback, which uses ‘torch.nn.utils.clip_grad_value_’ to clip the gradients of the given parameters to the given value. If params is None they will be retrieved from state.

Parameters:
  • clip_value – The maximum absolute value of the gradient
  • params – The parameters to clip or None
on_backward(state)[source]

Between the backward pass (which computes the gradients) and the step call (which updates the parameters), clip the gradient.

Parameters:state (dict) – The Model state
on_start(state)[source]

If params is None then retrieve from the model.

Parameters:state (dict) – The Model state
class torchbearer.callbacks.gradient_clipping.GradientNormClipping(max_norm, norm_type=2, params=None)[source]

GradientNormClipping callback, which uses ‘torch.nn.utils.clip_grad_norm_’ to clip the gradient norms to the given value. If params is None they will be retrieved from state.

Parameters:
  • max_norm – The max norm value
  • norm_type – The norm type to use
  • params – The parameters to clip or None
on_backward(state)[source]

Between the backward pass (which computes the gradients) and the step call (which updates the parameters), clip the gradient.

Parameters:state (dict) – The Model state
on_start(state)[source]

If params is None then retrieve from the model.

Parameters:state (dict) – The Model state
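
For example, a minimal sketch covering both callbacks (with params=None the parameters are retrieved from state at the start of the fit):

    from torchbearer.callbacks.gradient_clipping import GradientClipping, GradientNormClipping

    # Clip each gradient element to the range [-1, 1]
    clip_value = GradientClipping(clip_value=1.0)

    # Or rescale gradients so that their overall L2 norm is at most 5
    clip_norm = GradientNormClipping(max_norm=5.0, norm_type=2)
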

Learning Rate Schedulers

class torchbearer.callbacks.torch_scheduler.CosineAnnealingLR(T_max, eta_min=0, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch CosineAnnealingLR
class torchbearer.callbacks.torch_scheduler.ExponentialLR(gamma, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch ExponentialLR
class torchbearer.callbacks.torch_scheduler.LambdaLR(lr_lambda, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch LambdaLR
class torchbearer.callbacks.torch_scheduler.MultiStepLR(milestones, gamma=0.1, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch MultiStepLR
class torchbearer.callbacks.torch_scheduler.ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, step_on_batch=False)[source]
Parameters:monitor (str) – The quantity to monitor. (Default value = ‘val_loss’)
See:
PyTorch ReduceLROnPlateau
class torchbearer.callbacks.torch_scheduler.StepLR(step_size, gamma=0.1, last_epoch=-1, step_on_batch=False)[source]
See:
PyTorch StepLR
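
These callbacks mirror the arguments of their PyTorch counterparts, with the optimiser taken from state rather than passed in. For example, a minimal sketch:

    from torchbearer.callbacks.torch_scheduler import MultiStepLR, ReduceLROnPlateau

    # Multiply the learning rate by 0.1 at epochs 30 and 60
    milestones = MultiStepLR(milestones=[30, 60], gamma=0.1)

    # Or reduce the learning rate when val_loss plateaus for 10 epochs
    plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
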
class torchbearer.callbacks.torch_scheduler.TorchScheduler(scheduler_builder, monitor=None, step_on_batch=False)[source]
on_end_epoch(state)[source]

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)[source]

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)[source]

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)[source]

Perform some action with the given state as context at the start of the training loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)[source]

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.

Weight Decay

class torchbearer.callbacks.weight_decay.L1WeightDecay(rate=0.0005, params=None)[source]

WeightDecay callback which uses an L1 norm with the given rate and parameters. If params is None (default) then the parameters will be retrieved from the model.

Parameters:
  • rate (float) – The decay rate
  • params (list) – The parameters to use (or None)
class torchbearer.callbacks.weight_decay.L2WeightDecay(rate=0.0005, params=None)[source]

WeightDecay callback which uses an L2 norm with the given rate and parameters. If params is None (default) then the parameters will be retrieved from the model.

Parameters:
  • rate (float) – The decay rate
  • params (list) – The parameters to use (or None)
class torchbearer.callbacks.weight_decay.WeightDecay(rate=0.0005, p=2, params=None)[source]

Create a WeightDecay callback which uses the given norm on the given parameters and with the given decay rate. If params is None (default) then the parameters will be retrieved from the model.

Parameters:
  • rate (float) – The decay rate
  • p (int) – The norm level
  • params (list) – The parameters to use (or None)
on_criterion(state)[source]

Calculate the decay term and add to state[‘loss’].

Parameters:state (dict) – The Model state
on_start(state)[source]

Retrieve params from state[‘model’] if required.

Parameters:state (dict) – The Model state
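
For example, a minimal sketch (the decay term is added to state['loss'] after each criterion evaluation):

    from torchbearer.callbacks.weight_decay import L1WeightDecay, L2WeightDecay

    # Add an L1 penalty on all model parameters to the loss
    l1 = L1WeightDecay(rate=5e-4)

    # Or an L2 penalty with a smaller rate
    l2 = L2WeightDecay(rate=1e-4)
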

Decorators

class torchbearer.callbacks.decorators.LambdaCallback(func)[source]
on_lambda(state)[source]
torchbearer.callbacks.decorators.add_to_loss(func)[source]

The add_to_loss() decorator is used to initialise a Callback with the value returned from func being added to the loss

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback which adds the returned value from func to the loss
Return type:Callback
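
For example, a minimal sketch (the string state key 'y_pred' for the model output is an assumption here):

    import torchbearer.callbacks.decorators as decorators

    # Add a small L2 penalty on the model output to the loss each step
    @decorators.add_to_loss
    def output_penalty(state):
        # 'y_pred' is assumed to hold the model output tensor
        return 1e-4 * state['y_pred'].pow(2).sum()
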
torchbearer.callbacks.decorators.bind_to(target)[source]
torchbearer.callbacks.decorators.on_backward(func)[source]

The on_backward() decorator is used to initialise a Callback with on_backward() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_backward() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_criterion(func)[source]

The on_criterion() decorator is used to initialise a Callback with on_criterion() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_criterion() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_criterion_validation(func)[source]

The on_criterion_validation() decorator is used to initialise a Callback with on_criterion_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_criterion_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end(func)[source]

The on_end() decorator is used to initialise a Callback with on_end() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end_epoch(func)[source]

The on_end_epoch() decorator is used to initialise a Callback with on_end_epoch() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end_epoch() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end_training(func)[source]

The on_end_training() decorator is used to initialise a Callback with on_end_training() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end_training() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_end_validation(func)[source]

The on_end_validation() decorator is used to initialise a Callback with on_end_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_end_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_forward(func)[source]

The on_forward() decorator is used to initialise a Callback with on_forward() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_forward() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_forward_validation(func)[source]

The on_forward_validation() decorator is used to initialise a Callback with on_forward_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_forward_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_sample(func)[source]

The on_sample() decorator is used to initialise a Callback with on_sample() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_sample() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_sample_validation(func)[source]

The on_sample_validation() decorator is used to initialise a Callback with on_sample_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_sample_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start(func)[source]

The on_start() decorator is used to initialise a Callback with on_start() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with on_start() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start_epoch(func)[source]

The on_start_epoch() decorator is used to initialise a Callback with on_start_epoch() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with on_start_epoch() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start_training(func)[source]

The on_start_training() decorator is used to initialise a Callback with on_start_training() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_start_training() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_start_validation(func)[source]

The on_start_validation() decorator is used to initialise a Callback with on_start_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_start_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_step_training(func)[source]

The on_step_training() decorator is used to initialise a Callback with on_step_training() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_step_training() calling func
Return type:Callback
torchbearer.callbacks.decorators.on_step_validation(func)[source]

The on_step_validation() decorator is used to initialise a Callback with on_step_validation() calling the decorated function

Parameters:func (function) – The function(state) to decorate
Returns:Initialised callback with Callback.on_step_validation() calling func
Return type:Callback
torchbearer.callbacks.decorators.once(fcn)[source]

Decorator to fire a callback once in the entire fitting procedure.

Parameters:fcn – The torchbearer callback function to decorate
Returns:The decorator

torchbearer.callbacks.decorators.once_per_epoch(fcn)[source]

Decorator to fire a callback once (on the first call) in any given epoch.

Parameters:fcn – The torchbearer callback function to decorate
Returns:The decorator

torchbearer.callbacks.decorators.only_if(condition_expr)[source]

Decorator to fire a callback only if the given conditional expression function returns True.

Parameters:condition_expr – A function/lambda that must evaluate to True for the decorated torchbearer callback to be called. The state object passed to the callback will be passed as an argument to the condition function.
Returns:The decorator
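
For example, the event decorators above turn a plain function of state into a Callback that can be passed to a Trial. A minimal sketch (the string state key 'metrics' is an assumption here):

    import torchbearer.callbacks.decorators as decorators

    # Print the running metrics after each training step
    @decorators.on_step_training
    def log_metrics(state):
        print(state['metrics'])

    # Report once at the very start of the fit
    @decorators.on_start
    def announce(state):
        print('starting fit')

    # log_metrics and announce are now Callback instances and can be
    # passed in the callbacks list of a Trial

The once, once_per_epoch and only_if decorators wrap such callback functions to restrict when they fire.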