torchbearer.callbacks
Base Classes
-
class
torchbearer.bases.
Callback
[source]¶ Base callback class.
Note
All callbacks should override this class.
-
state_dict
()[source]¶ Get a dict containing the callback state.
Returns: A dict containing parameters and persistent buffers. Return type: dict
-
load_state_dict
(state_dict)[source]¶ Resume this callback from the given state. Expects that this callback was constructed in the same way.
Parameters: state_dict (dict) – The state dict to reload Returns: self Return type: Callback
-
on_init
(state)[source]¶ Perform some action with the given state as context at the init of a trial instance
Parameters: state (dict) – The current state dict of the Trial.
-
on_start
(state)[source]¶ Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict) – The current state dict of the Trial.
-
on_start_epoch
(state)[source]¶ Perform some action with the given state as context at the start of each epoch.
Parameters: state (dict) – The current state dict of the Trial.
-
on_start_training
(state)[source]¶ Perform some action with the given state as context at the start of the training loop.
Parameters: state (dict) – The current state dict of the Trial.
-
on_sample
(state)[source]¶ Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict) – The current state dict of the Trial.
-
on_forward
(state)[source]¶ Perform some action with the given state as context after the forward pass (model output) has been completed.
Parameters: state (dict) – The current state dict of the Trial.
-
on_criterion
(state)[source]¶ Perform some action with the given state as context after the criterion has been evaluated.
Parameters: state (dict) – The current state dict of the Trial.
-
on_backward
(state)[source]¶ Perform some action with the given state as context after backward has been called on the loss.
Parameters: state (dict) – The current state dict of the Trial.
-
on_step_training
(state)[source]¶ Perform some action with the given state as context after step has been called on the optimiser.
Parameters: state (dict) – The current state dict of the Trial.
-
on_end_training
(state)[source]¶ Perform some action with the given state as context after the training loop has completed.
Parameters: state (dict) – The current state dict of the Trial.
-
on_start_validation
(state)[source]¶ Perform some action with the given state as context at the start of the validation loop.
Parameters: state (dict) – The current state dict of the Trial.
-
on_sample_validation
(state)[source]¶ Perform some action with the given state as context after data has been sampled from the validation generator.
Parameters: state (dict) – The current state dict of the Trial.
-
on_forward_validation
(state)[source]¶ Perform some action with the given state as context after the forward pass (model output) has been completed with the validation data.
Parameters: state (dict) – The current state dict of the Trial.
-
on_criterion_validation
(state)[source]¶ Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.
Parameters: state (dict) – The current state dict of the Trial.
-
on_step_validation
(state)[source]¶ Perform some action with the given state as context at the end of each validation step.
Parameters: state (dict) – The current state dict of the Trial.
-
on_end_validation
(state)[source]¶ Perform some action with the given state as context at the end of the validation loop.
Parameters: state (dict) – The current state dict of the Trial.
-
on_end_epoch
(state)[source]¶ Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict) – The current state dict of the Trial.
-
on_checkpoint
(state)[source]¶ Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.
Parameters: state (dict) – The current state dict of the Trial.
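The order in which the training-side hooks fire can be pictured with a small stand-in. This is a sketch only, not torchbearer's implementation: `SketchCallback`, `RecordingCallback` and `run_epoch` are hypothetical names, and the toy loop simply follows the order implied by the hook descriptions above (sample, forward, criterion, backward, step).

```python
class SketchCallback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""

    def on_start_epoch(self, state): pass
    def on_start_training(self, state): pass
    def on_sample(self, state): pass
    def on_forward(self, state): pass
    def on_criterion(self, state): pass
    def on_backward(self, state): pass
    def on_step_training(self, state): pass
    def on_end_training(self, state): pass
    def on_end_epoch(self, state): pass


class RecordingCallback(SketchCallback):
    """Overrides only the hooks it cares about and records them in state."""

    def on_start_epoch(self, state):
        state["events"].append("start_epoch")

    def on_backward(self, state):
        state["events"].append("backward")

    def on_end_epoch(self, state):
        state["events"].append("end_epoch")


def run_epoch(callback, state, num_batches):
    """Toy loop (an assumption) that fires the training hooks in order."""
    callback.on_start_epoch(state)
    callback.on_start_training(state)
    for _ in range(num_batches):
        callback.on_sample(state)
        callback.on_forward(state)
        callback.on_criterion(state)
        callback.on_backward(state)
        callback.on_step_training(state)
    callback.on_end_training(state)
    callback.on_end_epoch(state)
    return state


state = run_epoch(RecordingCallback(), {"events": []}, num_batches=2)
print(state["events"])
```

Because every hook in the base class does nothing, a subclass only needs to override the hooks it actually uses.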
-
-
class
torchbearer.callbacks.callbacks.
CallbackList
(callback_list)[source]¶ The CallbackList class is a wrapper for a list of callbacks which acts as a single Callback and internally calls each Callback in the given list in turn.
Parameters: callback_list (list) – The list of callbacks to be wrapped. If the list contains a CallbackList, this will be unwrapped.
-
CALLBACK_STATES
= 'callback_states'¶
-
CALLBACK_TYPES
= 'callback_types'¶
-
state_dict
()[source]¶ Get a dict containing all of the callback states.
Returns: A dict containing parameters and persistent buffers. Return type: dict
-
load_state_dict
(state_dict)[source]¶ Resume this callback list from the given state. Callbacks must be given in the same order for this to work.
Parameters: state_dict (dict) – The state dict to reload Returns: self Return type: CallbackList
-
on_init
(state)[source]¶ Call on_init on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_start
(state)[source]¶ Call on_start on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_start_epoch
(state)[source]¶ Call on_start_epoch on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_start_training
(state)[source]¶ Call on_start_training on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_sample
(state)[source]¶ Call on_sample on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_forward
(state)[source]¶ Call on_forward on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_criterion
(state)[source]¶ Call on_criterion on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_backward
(state)[source]¶ Call on_backward on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_step_training
(state)[source]¶ Call on_step_training on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_end_training
(state)[source]¶ Call on_end_training on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_start_validation
(state)[source]¶ Call on_start_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_sample_validation
(state)[source]¶ Call on_sample_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_forward_validation
(state)[source]¶ Call on_forward_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_criterion_validation
(state)[source]¶ Call on_criterion_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_step_validation
(state)[source]¶ Call on_step_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_end_validation
(state)[source]¶ Call on_end_validation on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
-
on_end_epoch
(state)[source]¶ Call on_end_epoch on each callback in turn with the given state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
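The wrapping behaviour described for CallbackList (unwrapping nested lists, calling each callback in turn, and restoring state in order) can be sketched in plain Python. `SketchCallbackList` and `Counter` are hypothetical stand-ins, not the library's classes; only the `'callback_states'` key is taken from the constants listed above, and the callback types entry is omitted for brevity.

```python
class Counter:
    """Toy callback that counts how often on_start fires."""

    def __init__(self):
        self.count = 0

    def on_start(self, state):
        self.count += 1

    def state_dict(self):
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        self.count = state_dict["count"]
        return self


class SketchCallbackList:
    """Hypothetical stand-in: wraps callbacks and flattens nested lists."""

    def __init__(self, callback_list):
        self.callback_list = []
        for callback in callback_list:
            if isinstance(callback, SketchCallbackList):
                # Unwrap a nested list so every callback sits at one level.
                self.callback_list.extend(callback.callback_list)
            else:
                self.callback_list.append(callback)

    def on_start(self, state):
        # Call the hook on each wrapped callback in turn.
        for callback in self.callback_list:
            callback.on_start(state)

    def state_dict(self):
        # One entry per callback, stored in list order.
        return {"callback_states": [c.state_dict() for c in self.callback_list]}

    def load_state_dict(self, state_dict):
        # States are matched by position, so the order must be unchanged.
        for callback, saved in zip(self.callback_list, state_dict["callback_states"]):
            callback.load_state_dict(saved)
        return self


a, b = Counter(), Counter()
nested = SketchCallbackList([a, SketchCallbackList([b])])
nested.on_start({})
nested.on_start({})
saved = nested.state_dict()
restored = SketchCallbackList([Counter(), Counter()]).load_state_dict(saved)
print(saved)
```

Matching saved states by position is what makes the "same order" requirement of load_state_dict necessary in this sketch.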
-
Imaging
Main Classes
-
class
torchbearer.callbacks.imaging.imaging.
CachingImagingCallback
(key=x, transform=None, num_images=16)[source]¶ The CachingImagingCallback is an ImagingCallback which caches batches of images from the given state key up to the required amount before passing this along with state to the implementing class, once per epoch.
Parameters:
-
class
torchbearer.callbacks.imaging.imaging.
FromState
(key, transform=None, decorator=None)[source]¶ The FromState callback is an ImagingCallback which retrieves an image from state when called. The number of times the function is called can be controlled with a provided decorator (once_per_epoch, only_if etc.)
Parameters: - key (StateKey) – The StateKey containing the image (tensor of size [c, w, h])
- transform (callable, optional) – A function/transform that takes in a Tensor and returns a transformed version. This will be applied to the image before it is sent to output.
- decorator – A function which will be used to wrap the callback function. once_per_epoch by default
-
class
torchbearer.callbacks.imaging.imaging.
ImagingCallback
(transform=None)[source]¶ The ImagingCallback provides a generic interface for callbacks which yield images that should be sent to a file, tensorboard, visdom etc. without needing bespoke code. This allows the user to easily define custom visualisations by only writing the code to produce the image.
Parameters: transform (callable, optional) – A function/transform that takes in a Tensor and returns a transformed version. This will be applied to the image before it is sent to output.
-
cache
(num_images)[source]¶ Cache images before they are passed to handlers. Once per epoch, a single cache will be returned, containing the first num_images to be returned.
Parameters: num_images (int) – The number of images to cache Returns: self Return type: ImagingCallback
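The caching behaviour (keep the first num_images seen, hand them over once per epoch) can be illustrated without any tensors. The `ImageCache` helper below is hypothetical and uses plain lists as stand-ins for image batches; it is a sketch of the idea, not the library's code.

```python
class ImageCache:
    """Hypothetical helper: collects the first num_images seen in an epoch."""

    def __init__(self, num_images):
        self.num_images = num_images
        self.images = []
        self.done = False

    def add_batch(self, batch):
        """Return the full cache exactly once per epoch, otherwise None."""
        if self.done:
            return None
        remaining = self.num_images - len(self.images)
        self.images.extend(batch[:remaining])
        if len(self.images) == self.num_images:
            self.done = True
            return self.images
        return None

    def reset(self):
        # Called at the end of an epoch so the next epoch starts fresh.
        self.images = []
        self.done = False


cache = ImageCache(num_images=5)
batches = [["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"]]
results = [cache.add_batch(batch) for batch in batches]
print(results)
```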
-
make_grid
(nrow=8, padding=2, normalize=False, norm_range=None, scale_each=False, pad_value=0)[source]¶ Use torchvision.utils.make_grid to make a grid of the images being returned by this callback. Recommended for use alongside cache.
Parameters: - nrow – See torchvision.utils.make_grid
- padding –
- normalize –
- norm_range –
- scale_each –
- pad_value –
Returns: self Return type: ImagingCallback
-
on_test
()[source]¶ Process this callback for test batches
Returns: self Return type: ImagingCallback
-
on_train
()[source]¶ Process this callback for training batches
Returns: self Return type: ImagingCallback
-
on_val
()[source]¶ Process this callback for validation batches
Returns: self Return type: ImagingCallback
-
to_file
(filename, index=None)[source]¶ Send images from this callback to the given file
Parameters: - filename (str) – the filename to store the image to
- index (int or list or None) – if not None, only apply the handler on this index / list of indices
Returns: self Return type: ImagingCallback
-
to_pyplot
(index=None)[source]¶ Show images from this callback with pyplot
Parameters: index (int or list or None) – if not None, only apply the handler on this index / list of indices Returns: self Return type: ImagingCallback
-
to_state
(keys, index=None)[source]¶ Put images from this callback in state with the given key
Parameters:
Returns: self Return type: ImagingCallback
-
to_tensorboard
(name='Image', log_dir='./logs', comment='torchbearer', index=None)[source]¶ Direct images from this callback to tensorboard with the given parameters
Parameters: - name (str) – The name of the image
- log_dir (str) – The tensorboard log path for output
- comment (str) – Descriptive comment to append to path
- index (int or list or None) – if not None, only apply the handler on this index / list of indices
Returns: self Return type: ImagingCallback
-
to_visdom
(name='Image', log_dir='./logs', comment='torchbearer', visdom_params=None, index=None)[source]¶ Direct images from this callback to visdom with the given parameters
Parameters: - name (str) – The name of the image
- log_dir (str) – The visdom log path for output
- comment (str) – Descriptive comment to append to path
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
- index (int or list or None) – if not None, only apply the handler on this index / list of indices
Returns: self Return type: ImagingCallback
-
with_handler
(handler, index=None)[source]¶ Append the given output handler to the list of handlers
Parameters: - handler – A function of image and state which stores the given image in some way
- index (int or list or None) – if not None, only apply the handler on this index / list of indices
Returns: self Return type: ImagingCallback
-
-
class
torchbearer.callbacks.imaging.imaging.
MakeGrid
(key=x, transform=None, num_images=16, nrow=8, padding=2, normalize=False, norm_range=None, scale_each=False, pad_value=0)[source]¶ The MakeGrid callback is a CachingImagingCallback which calls make grid on the cache with the provided parameters.
Parameters: - key (StateKey) – The StateKey containing image data (tensor of size [b, c, w, h])
- transform (callable, optional) – A function/transform that takes in a Tensor and returns a transformed version. This will be applied to the image before it is sent to output.
- num_images – The number of images to cache
- nrow –
- padding –
- normalize –
- norm_range –
- scale_each –
- pad_value –
Deep Inside Convolutional Networks
-
class
torchbearer.callbacks.imaging.inside_cnns.
ClassAppearanceModel
(nclasses, input_size, optimizer_factory=<function ClassAppearanceModel.<lambda>>, steps=256, logit_key=y_pred, target=-10, decay=0.01, verbose=0, in_transform=None, transform=None)[source]¶ The ClassAppearanceModel callback implements Figure 1 from Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. This is a simple gradient ascent on an image (initialised to zero) with a sum-squares regularizer. Internally this creates a new Trial instance which then performs the optimization.
@article{simonyan2013deep,
  title={Deep inside convolutional networks: Visualising image classification models and saliency maps},
  author={Simonyan, Karen and Vedaldi, Andrea and Zisserman, Andrew},
  journal={arXiv preprint arXiv:1312.6034},
  year={2013}
}
Parameters: - nclasses (int) – The number of output classes
- input_size (tuple) – The size to use for the input image
- optimizer_factory – A function of parameters which returns an optimizer to use
- logit_key (StateKey) – StateKey storing the class logits
- target (int) – Target class for the optimisation or RANDOM
- steps (int) – Number of optimisation steps to take
- decay (float) – Lambda for the L2 decay on the image
- verbose (int) – Verbosity level to pass to the internal Trial instance
- transform (callable, optional) – A function/transform that takes in a Tensor and returns a transformed version. This will be applied to the image before it is sent to output
-
on_batch
(state)¶
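The optimisation described above, gradient ascent on a class score with a sum-squares penalty starting from a zero image, can be shown on a toy problem. The sketch below assumes a linear score w · I in place of a network, so the objective is w · I − decay · ‖I‖² and its maximiser is w / (2 · decay). The function name, the `lr` value and the plain gradient step are all assumptions made for illustration; the callback itself uses an internal Trial and the optimizer returned by optimizer_factory.

```python
def class_appearance_sketch(weights, steps=256, decay=0.01, lr=10.0):
    """Gradient ascent on score(I) - decay * ||I||^2 for a toy linear score."""
    image = [0.0] * len(weights)  # the image is initialised to zero
    for _ in range(steps):
        # Gradient of w . I - decay * ||I||^2 with respect to I is w - 2 * decay * I
        grads = [w - 2.0 * decay * x for w, x in zip(weights, image)]
        # Ascent step: move along the gradient to increase the objective.
        image = [x + lr * g for x, g in zip(image, grads)]
    return image


image = class_appearance_sketch([1.0, -0.5])
print(image)
```

With decay=0.01 the regulariser stops the image from growing without bound: each pixel settles where the score gradient and the penalty gradient cancel.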
-
torchbearer.callbacks.imaging.inside_cnns.
RANDOM
= -10¶ Flag that, when passed as the target, chooses a random target
Model Checkpointers
-
class
torchbearer.callbacks.checkpointers.
Best
(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', save_model_params_only=False, monitor='val_loss', mode='auto', period=1, min_delta=0, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)[source]¶ Model checkpointer which saves the best model according to the given configurations. filepath can contain named formatting options, which will be filled with values from state. For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename.
Example:
>>> from torchbearer.callbacks import Best
>>> from torchbearer import Trial
>>> import torch
# Example Trial (without optimiser or loss criterion) which uses this checkpointer
>>> model = torch.nn.Linear(1,1)
>>> checkpoint = Best('my_path.pt', monitor='val_acc', mode='max')
>>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
Parameters: - filepath (str) – Path to save the model file
- save_model_params_only (bool) – If save_model_params_only=True, only model parameters will be saved so that the results can be loaded into a PyTorch nn.Module. The other option, save_model_params_only=False, should be used only if the results will be loaded into a Torchbearer Trial object later.
- monitor (str) – Quantity to monitor
- mode (str) – One of {auto, min, max}. The decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
- period (int) – Interval (number of epochs) between checkpoints
- min_delta (float) – The minimum improvement required to trigger a save
- pickle_module (module) – The pickle module to use, default is ‘torch.serialization.pickle’
- pickle_protocol (int) – The pickle protocol to use, default is ‘torch.serialization.DEFAULT_PROTOCOL’
State Requirements:
- torchbearer.state.MODEL: Model should have the state_dict method
- torchbearer.state.METRICS: Metrics dictionary should exist, with the monitor key populated
- torchbearer.state.SELF: Self should be the torchbearer.Trial which is running this callback
-
load_state_dict
(state_dict)[source]¶ Resume this callback from the given state. Expects that this callback was constructed in the same way.
Parameters: state_dict (dict) – The state dict to reload Returns: self Return type: Callback
-
on_checkpoint
(state)[source]¶ Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.
Parameters: state (dict) – The current state dict of the Trial.
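The named formatting options in filepath behave like Python `str.format` fields. A minimal sketch, assuming the values are available in a plain dict (the `values` dict below is hypothetical; in a real run they would be taken from state):

```python
filepath = "model.{epoch:02d}-{val_loss:.2f}.pt"

# Hypothetical values standing in for entries taken from state.
values = {"epoch": 3, "val_loss": 0.4567}

# {epoch:02d} zero-pads to two digits, {val_loss:.2f} keeps two decimals.
filename = filepath.format(**values)
print(filename)  # model.03-0.46.pt
```

A fixed name such as 'my_path.pt' contains no fields, so formatting leaves it unchanged and each save targets the same file.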
-
class
torchbearer.callbacks.checkpointers.
Interval
(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', save_model_params_only=False, period=1, on_batch=False, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)[source]¶ Model checkpointer which saves the model every ‘period’ steps to the given filepath, where a step is a batch if on_batch is True and an epoch otherwise. filepath can contain named formatting options, which will be filled with values from state. For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename.
Example:
>>> from torchbearer.callbacks import Interval
>>> from torchbearer import Trial
>>> import torch
# Example Trial (without optimiser or loss criterion) which uses this checkpointer
>>> model = torch.nn.Linear(1,1)
>>> checkpoint = Interval('my_path.pt', period=100, on_batch=True)
>>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
Parameters: - filepath (str) – Path to save the model file
- save_model_params_only (bool) – If save_model_params_only=True, only model parameters will be saved so that the results can be loaded into a PyTorch nn.Module. The other option, save_model_params_only=False, should be used only if the results will be loaded into a Torchbearer Trial object later.
- period (int) – Interval (number of steps) between checkpoints, counted in batches if on_batch is True and in epochs otherwise
- on_batch (bool) – If true step each batch, if false step each epoch.
- pickle_module (module) – The pickle module to use, default is ‘torch.serialization.pickle’
- pickle_protocol (int) – The pickle protocol to use, default is ‘torch.serialization.DEFAULT_PROTOCOL’
State Requirements:
- torchbearer.state.MODEL: Model should have the state_dict method
- torchbearer.state.METRICS: Metrics dictionary should exist
- torchbearer.state.SELF: Self should be the torchbearer.Trial which is running this callback
-
load_state_dict
(state_dict)[source]¶ Resume this callback from the given state. Expects that this callback was constructed in the same way.
Parameters: state_dict (dict) – The state dict to reload Returns: self Return type: Callback
-
on_checkpoint
(state)[source]¶ Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.
Parameters: state (dict) – The current state dict of the Trial.
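The effect of period can be pictured as a counter that triggers a save every period steps. `IntervalCounter` is a hypothetical illustration of that counting, not the library's code; whether a step is a batch or an epoch depends on on_batch.

```python
class IntervalCounter:
    """Hypothetical helper: reports True on every `period`-th step."""

    def __init__(self, period=1):
        self.period = period
        self.steps_since_save = 0

    def step(self):
        self.steps_since_save += 1
        if self.steps_since_save >= self.period:
            # Reset so the next save happens `period` steps from now.
            self.steps_since_save = 0
            return True
        return False


counter = IntervalCounter(period=3)
saves = [step for step in range(1, 10) if counter.step()]
print(saves)  # [3, 6, 9]
```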
-
torchbearer.callbacks.checkpointers.
ModelCheckpoint
(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', save_model_params_only=False, monitor='val_loss', save_best_only=False, mode='auto', period=1, min_delta=0)[source]¶ Save the model after every epoch. filepath can contain named formatting options, which will be filled with values from state. For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename. The torch Trial will be saved to filename.
Example:
>>> from torchbearer.callbacks import ModelCheckpoint
>>> from torchbearer import Trial
>>> import torch
# Example Trial (without optimiser or loss criterion) which uses this checkpointer
>>> model = torch.nn.Linear(1,1)
>>> checkpoint = ModelCheckpoint('my_path.pt', monitor='val_acc', mode='max')
>>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
Parameters: - filepath (str) – Path to save the model file
- save_model_params_only (bool) – If save_model_params_only=True, only model parameters will be saved so that the results can be loaded into a PyTorch nn.Module. The other option, save_model_params_only=False, should be used only if the results will be loaded into a Torchbearer Trial object later.
- monitor (str) – Quantity to monitor
- save_best_only (bool) – If save_best_only=True, the latest best model according to the quantity monitored will not be overwritten
- mode (str) – One of {auto, min, max}. If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
- period (int) – Interval (number of epochs) between checkpoints
- min_delta (float) – If save_best_only=True, this is the minimum improvement required to trigger a save
State Requirements:
- torchbearer.state.MODEL: Model should have the state_dict method
- torchbearer.state.METRICS: Metrics dictionary should exist
- torchbearer.state.SELF: Self should be the torchbearer.Trial which is running this callback
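How mode and min_delta interact can be sketched as a comparison function. The helper below is hypothetical, and the rule used for auto mode (names containing "acc" are maximised, everything else minimised) is an assumption for illustration; the description above only says that the direction is inferred from the name of the monitored quantity.

```python
def make_monitor_op(monitor, mode="auto", min_delta=0.0):
    """Return improved(current, best) for the given mode (illustrative only)."""
    if mode == "auto":
        # Assumed heuristic: accuracy-like names are maximised, others minimised.
        mode = "max" if "acc" in monitor else "min"
    if mode == "max":
        return lambda current, best: current - min_delta > best
    if mode == "min":
        return lambda current, best: current + min_delta < best
    raise ValueError("mode must be one of 'auto', 'min' or 'max'")


improved_loss = make_monitor_op("val_loss", min_delta=0.01)
improved_acc = make_monitor_op("val_acc")

print(improved_loss(0.50, 0.52))   # beats the best by more than min_delta
print(improved_loss(0.515, 0.52))  # better, but by less than min_delta
print(improved_acc(0.90, 0.80))
```

In this sketch min_delta acts as a margin: a new value only counts as an improvement if it beats the previous best by more than that amount.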
-
class
torchbearer.callbacks.checkpointers.
MostRecent
(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', save_model_params_only=False, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)[source]¶ Model checkpointer which saves the most recent model to a given filepath. filepath can contain named formatting options, which will be filled with values from state. For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}, then the model checkpoints will be saved with the epoch number and the validation loss in the filename.
Example:
>>> from torchbearer.callbacks import MostRecent
>>> from torchbearer import Trial
>>> import torch
# Example Trial (without optimiser or loss criterion) which uses this checkpointer
>>> model = torch.nn.Linear(1,1)
>>> checkpoint = MostRecent('my_path.pt')
>>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
Parameters: - filepath (str) – Path to save the model file
- save_model_params_only (bool) – If save_model_params_only=True, only model parameters will be saved so that the results can be loaded into a PyTorch nn.Module. The other option, save_model_params_only=False, should be used only if the results will be loaded into a Torchbearer Trial object later.
- pickle_module (module) – The pickle module to use, default is ‘torch.serialization.pickle’
- pickle_protocol (int) – The pickle protocol to use, default is ‘torch.serialization.DEFAULT_PROTOCOL’
State Requirements:
- torchbearer.state.MODEL: Model should have the state_dict method
- torchbearer.state.METRICS: Metrics dictionary should exist
- torchbearer.state.SELF: Self should be the torchbearer.Trial which is running this callback
-
on_checkpoint
(state)[source]¶ Perform some action with the state after all other callbacks have completed at the end of an epoch and the history has been updated. Should only be used for taking checkpoints or snapshots and will only be called by the run method of Trial.
Parameters: state (dict) – The current state dict of the Trial.
Logging
-
class
torchbearer.callbacks.csv_logger.
CSVLogger
(filename, separator=', ', batch_granularity=False, write_header=True, append=False)[source]¶ Callback to log metrics to a given csv file.
Example:
>>> from torchbearer.callbacks import CSVLogger
>>> from torchbearer import Trial
>>> import torch
# Example Trial (without optimiser or loss criterion) which writes metrics to a csv file appending to previous content
>>> logger = CSVLogger('my_path.pt', separator=',', append=True)
>>> trial = Trial(None, callbacks=[logger], metrics=['acc'])
Parameters: - filename (str) – The name of the file to output to
- separator (str) – The delimiter to use (e.g. comma, tab etc.)
- batch_granularity (bool) – If True, write on each batch, else on each epoch
- write_header (bool) – If True, write the CSV header at the beginning of training
- append (bool) – If True, append to the file instead of replacing it
State Requirements:
- torchbearer.state.EPOCH: State should have the current epoch stored
- torchbearer.state.METRICS: Metrics dictionary should exist
- torchbearer.state.BATCH: State should have the current batch stored if using batch_granularity
-
on_end
(state)[source]¶ Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict) – The current state dict of the Trial.
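The roles of separator, write_header and append can be illustrated with the standard `csv` module. `log_metrics` is a hypothetical stand-in for the logger, not its implementation, and it assumes a single-character separator because `csv.writer` requires one.

```python
import csv
import os
import tempfile


def log_metrics(filename, rows, separator=",", write_header=True, append=False):
    """Write one CSV row per metrics dict; a simplified stand-in for a logger."""
    mode = "a" if append else "w"
    with open(filename, mode, newline="") as handle:
        writer = csv.writer(handle, delimiter=separator)
        for i, metrics in enumerate(rows):
            if i == 0 and write_header:
                writer.writerow(metrics.keys())
            writer.writerow(metrics.values())


path = os.path.join(tempfile.mkdtemp(), "log.csv")
log_metrics(path, [{"epoch": 0, "val_loss": 0.9}, {"epoch": 1, "val_loss": 0.7}])
# A later run appends to the same file without repeating the header.
log_metrics(path, [{"epoch": 2, "val_loss": 0.6}], write_header=False, append=True)

with open(path, newline="") as handle:
    lines = handle.read().splitlines()
print(lines)
```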
-
class
torchbearer.callbacks.printer.
ConsolePrinter
(validation_label_letter='v', precision=4)[source]¶ The ConsolePrinter callback simply outputs the training metrics to the console.
Example:
>>> import torch.nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import ConsolePrinter
# Example Trial which forgoes the usual printer for a console printer
>>> printer = ConsolePrinter()
>>> trial = Trial(None, callbacks=[printer], verbose=0).for_steps(1).run()
0/1(t):
Parameters: - validation_label_letter (str) – This is the letter displayed after the epoch number indicating the current phase of training
- precision (int) – Precision of the number format in decimal places
State Requirements:
- torchbearer.state.EPOCH: The current epoch number
- torchbearer.state.MAX_EPOCHS: The total number of epochs for this run
- torchbearer.state.BATCH: The current batch / iteration number
- torchbearer.state.STEPS: The total number of steps / batches / iterations for this epoch
- torchbearer.state.METRICS: The metrics dict to print
-
on_end_training
(state)[source]¶ Perform some action with the given state as context after the training loop has completed.
Parameters: state (dict) – The current state dict of the Trial.
-
on_end_validation
(state)[source]¶ Perform some action with the given state as context at the end of the validation loop.
Parameters: state (dict) – The current state dict of the Trial.
-
class
torchbearer.callbacks.printer.
Tqdm
(tqdm_module=None, validation_label_letter='v', precision=4, on_epoch=False, **tqdm_args)[source]¶ The Tqdm callback outputs the progress and metrics for training and validation loops to the console using TQDM. The given key is used to label validation output.
Example:
>>> import torch.nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import Tqdm
# Example Trial which forgoes the usual printer for a customised tqdm printer.
>>> printer = Tqdm(precision=8)
# Note that outputs are written to stderr, not stdout as shown in this example
>>> trial = Trial(None, callbacks=[printer], verbose=0).for_steps(1).run(1)
0/1(t): 100%|...| 1/1 [00:00<00:00, 29.40it/s]
Parameters: - tqdm_module – The tqdm module to use. If none, defaults to tqdm or tqdm_notebook if in notebook
- validation_label_letter (str) – The letter to use for validation outputs.
- precision (int) – Precision of the number format in decimal places
- on_epoch (bool) – If True, output a single progress bar which tracks epochs
- tqdm_args – Any extra keyword args provided here will be passed through to the tqdm module constructor. See github.com/tqdm/tqdm#parameters for more details.
State Requirements:
- torchbearer.state.EPOCH: The current epoch number
- torchbearer.state.MAX_EPOCHS: The total number of epochs for this run
- torchbearer.state.STEPS: The total number of steps / batches / iterations for this epoch
- torchbearer.state.METRICS: The metrics dict to print
- torchbearer.state.HISTORY: The history of the Trial object
-
on_end
(state)[source]¶ Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict) – The current state dict of the Trial.
-
on_end_epoch
(state)[source]¶ Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict) – The current state dict of the Trial.
-
on_end_training
(state)[source]¶ Update the bar with the terminal training metrics and then close.
Parameters: state (dict) – The Trial state
-
on_end_validation
(state)[source]¶ Update the bar with the terminal validation metrics and then close.
Parameters: state (dict) – The Trial state
-
on_start
(state)[source]¶ Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict) – The current state dict of the Trial.
-
on_start_training
(state)[source]¶ Initialise the TQDM bar for this training phase.
Parameters: state (dict) – The Trial state
-
on_start_validation
(state)[source]¶ Initialise the TQDM bar for this validation phase.
Parameters: state (dict) – The Trial state
Tensorboard, Visdom and Others
-
class
torchbearer.callbacks.tensor_board.
AbstractTensorBoard
(log_dir='./logs', comment='torchbearer', visdom=False, visdom_params=None)[source]¶ TensorBoard callback which writes metrics to the given log directory. Requires the TensorboardX library for python.
Parameters: - log_dir (str) – The tensorboard log path for output
- comment (str) – Descriptive comment to append to path
- visdom (bool) – If true, log to visdom instead of tensorboard
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
State Requirements:
- torchbearer.state.MODEL: PyTorch model
-
static
add_metric
(add_fn, tag, metric, *args, **kwargs)[source]¶ Static method that recurses through metric until the add_fn can be applied. Useful when metric is an iterable of tensors so that the tensors can all be passed to an add_fn such as writer.add_scalar. For example, if passed metric as [[A, B], [C, ], D, {‘E’: E}] then add_fn would be called on A, B, C, D and E and the respective tags (with base tag ‘met’) would be: met_0_0, met_0_1, met_1_0, met_2, met_E. Throws a warning if add_fn fails to parse a metric.
Parameters: - add_fn – Function to be called to log a metric, e.g. SummaryWriter.add_scalar
- tag – Tag under which to log the metric
- metric – Iterable of metrics.
- *args – Args for add_fn
- **kwargs – Keyword args for add_fn
Returns:
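The recursion and tag naming described for add_metric can be reproduced in a short sketch. The function below is a hypothetical reimplementation written so that its tags match the example in the description (met_0_0, met_0_1, met_1_0, met_2, met_E); in particular, the choice to give a dict nested in a list no index of its own is an assumption made to match that example, and the library's actual rule may differ.

```python
def add_metric_sketch(add_fn, tag, metric):
    """Recurse through nested lists, tuples and dicts, calling add_fn on leaves."""
    if isinstance(metric, dict):
        for key, value in metric.items():
            add_metric_sketch(add_fn, "{}_{}".format(tag, key), value)
    elif isinstance(metric, (list, tuple)):
        for index, value in enumerate(metric):
            # Assumption: a dict contributes its own keys, so it gets no index.
            if isinstance(value, dict):
                child_tag = tag
            else:
                child_tag = "{}_{}".format(tag, index)
            add_metric_sketch(add_fn, child_tag, value)
    else:
        add_fn(tag, metric)


logged = {}


def record(tag, value):
    # Stand-in for something like a writer's scalar-logging method.
    logged[tag] = value


add_metric_sketch(record, "met", [[1.0, 2.0], [3.0], 4.0, {"E": 5.0}])
print(sorted(logged))
```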
-
close_writer
(log_dir=None)[source]¶ Decrement the reference count for a writer belonging to the given log directory (or the default writer if the directory is not given). If the reference count gets to zero, the writer will be closed and removed.
Parameters: log_dir (str) – the (optional) directory
-
get_writer
(log_dir=None, visdom=False, visdom_params=None)[source]¶ Get a SummaryWriter for the given directory (or the default writer if the directory is not given). If you are getting a SummaryWriter for a custom directory, it is your responsibility to close it using close_writer.
Parameters: - log_dir (str) – the (optional) directory
- visdom (bool) – If true, return VisdomWriter, if false return tensorboard SummaryWriter
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
Returns: the SummaryWriter or VisdomWriter
-
class
torchbearer.callbacks.tensor_board.
TensorBoard
(log_dir='./logs', write_graph=True, write_batch_metrics=False, batch_step_size=10, write_epoch_metrics=True, comment='torchbearer', visdom=False, visdom_params=None)[source]¶ TensorBoard callback which writes metrics to the given log directory. Requires the TensorboardX library for python.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import TensorBoard
>>> from datetime import datetime
>>> current_time = datetime.now().strftime('%b%d_%H-%M-%S')
>>> # Callback that will log to tensorboard under "(model name)_(current time)"
>>> tb = TensorBoard(log_dir='./logs', write_graph=False, comment=current_time)
>>> # Trial that will run the callback and log accuracy and loss metrics
>>> t = Trial(None, callbacks=[tb], metrics=['acc', 'loss'])
Parameters: - log_dir (str) – The tensorboard log path for output
- write_graph (bool) – If True, the model graph will be written using the TensorboardX library
- write_batch_metrics (bool) – If True, batch metrics will be written
- batch_step_size (int) – The step size to use when writing batch metrics, make this larger to reduce latency
- write_epoch_metrics (bool) – If True, metrics from the end of the epoch will be written
- comment (str) – Descriptive comment to append to path
- visdom (bool) – If true, log to visdom instead of tensorboard
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
- State Requirements:
  - torchbearer.state.MODEL: PyTorch model
  - torchbearer.state.EPOCH: State should have the current epoch stored
  - torchbearer.state.X: State should have the current data stored if a model graph is to be built
  - torchbearer.state.BATCH: State should have the current batch number stored if logging batch metrics
  - torchbearer.state.TRAIN_STEPS: State should have the number of training steps stored
  - torchbearer.state.METRICS: State should have a dictionary of metrics stored
-
on_end
(state)[source]¶ Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_end_epoch
(state)[source]¶ Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_sample
(state)[source]¶ Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_start_epoch
(state)[source]¶ Perform some action with the given state as context at the start of each epoch.
Parameters: state (dict) – The current state dict of the Trial
.
-
class
torchbearer.callbacks.tensor_board.
TensorBoardImages
(log_dir='./logs', comment='torchbearer', name='Image', key=y_pred, write_each_epoch=True, num_images=16, nrow=8, padding=2, normalize=False, norm_range=None, scale_each=False, pad_value=0, visdom=False, visdom_params=None)[source]¶ The TensorBoardImages callback will write a selection of images from the validation pass to tensorboard using the TensorboardX library and torchvision.utils.make_grid (requires torchvision). Images are selected from the given key and saved to the given path. Full name of image sub directory will be model name + _ + comment.
Example:
>>> from torchbearer import Trial, state_key
>>> from torchbearer.callbacks import TensorBoardImages
>>> from datetime import datetime
>>> current_time = datetime.now().strftime('%b%d_%H-%M-%S')
>>> IMAGE_KEY = state_key('image_key')
>>> # Callback that will log to tensorboard under "(model name)_(current time)"
>>> tb = TensorBoardImages(comment=current_time, name='test_image', key=IMAGE_KEY)
>>> # Trial that will log the images stored under IMAGE_KEY to tensorboard
>>> t = Trial(None, callbacks=[tb], metrics=['acc', 'loss'])
Parameters: - log_dir (str) – The tensorboard log path for output
- comment (str) – Descriptive comment to append to path
- name (str) – The name of the image
- key (StateKey) – The key in state containing image data (tensor of size [c, w, h] or [b, c, w, h])
- write_each_epoch (bool) – If True, write data on every epoch, else write only for the first epoch.
- num_images (int) – The number of images to write
- nrow –
- padding –
- normalize –
- norm_range –
- scale_each –
- pad_value –
- visdom (bool) – If true, log to visdom instead of tensorboard
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
- State Requirements:
  - torchbearer.state.EPOCH: State should have the current epoch stored
  - key: State should have images stored under the given state key
-
class
torchbearer.callbacks.tensor_board.
TensorBoardProjector
(log_dir='./logs', comment='torchbearer', num_images=100, avg_pool_size=1, avg_data_channels=True, write_data=True, write_features=True, features_key=y_pred)[source]¶ The TensorBoardProjector callback is used to write images from the validation pass to Tensorboard using the TensorboardX library. Images are written to the given directory and, if required, so are associated features.
Parameters: - log_dir (str) – The tensorboard log path for output
- comment (str) – Descriptive comment to append to path
- num_images (int) – The number of images to write
- avg_pool_size (int) – Size of the average pool to perform on the image. This is recommended to reduce the overall image sizes and improve latency
- avg_data_channels (bool) – If True, the image data will be averaged in the channel dimension
- write_data (bool) – If True, the raw data will be written as an embedding
- write_features (bool) – If True, the image features will be written as an embedding
- features_key (StateKey) – The key in state to use for the embedding. Typically model output but can be used to show features from any layer of the model.
- State Requirements:
  - torchbearer.state.EPOCH: State should have the current epoch stored
  - torchbearer.state.X: State should have the current data stored
  - torchbearer.state.Y_TRUE: State should have the current targets stored
-
class
torchbearer.callbacks.tensor_board.
TensorBoardText
(log_dir='./logs', write_epoch_metrics=True, write_batch_metrics=False, log_trial_summary=True, batch_step_size=100, comment='torchbearer', visdom=False, visdom_params=None)[source]¶ TensorBoard callback which writes metrics as text to the given log directory. Requires the TensorboardX library for python.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import TensorBoardText
>>> from datetime import datetime
>>> current_time = datetime.now().strftime('%b%d_%H-%M-%S')
>>> # Callback that will log to tensorboard under "(model name)_(current time)"
>>> tb = TensorBoardText(comment=current_time)
>>> # Trial that will run the callback and log accuracy and loss metrics as text to tensorboard
>>> t = Trial(None, callbacks=[tb], metrics=['acc', 'loss'])
Parameters: - log_dir (str) – The tensorboard log path for output
- write_epoch_metrics (bool) – If True, metrics from the end of the epoch will be written
- log_trial_summary (bool) – If True logs a string summary of the Trial
- batch_step_size (int) – The step size to use when writing batch metrics, make this larger to reduce latency
- comment (str) – Descriptive comment to append to path
- visdom (bool) – If true, log to visdom instead of tensorboard
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
- State Requirements:
  - torchbearer.state.SELF: The torchbearer.Trial running this callback
  - torchbearer.state.EPOCH: State should have the current epoch stored
  - torchbearer.state.BATCH: State should have the current batch number stored if logging batch metrics
  - torchbearer.state.METRICS: State should have a dictionary of metrics stored
-
on_end
(state)[source]¶ Perform some action with the given state as context at the end of the model fitting.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_end_epoch
(state)[source]¶ Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_start
(state)[source]¶ Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_start_epoch
(state)[source]¶ Perform some action with the given state as context at the start of each epoch.
Parameters: state (dict) – The current state dict of the Trial
.
-
class
torchbearer.callbacks.tensor_board.
VisdomParams
[source]¶ Class to hold visdom client arguments. Modify member variables before initialising tensorboard callbacks for custom arguments. See: visdom
-
ENDPOINT
= 'events'¶
-
ENV
= 'main'¶
-
HTTP_PROXY_HOST
= None¶
-
HTTP_PROXY_PORT
= None¶
-
IPV6
= True¶
-
LOG_TO_FILENAME
= None¶
-
PORT
= 8097¶
-
RAISE_EXCEPTIONS
= None¶
-
SEND
= True¶
-
SERVER
= 'http://localhost'¶
-
USE_INCOMING_SOCKET
= True¶
-
-
torchbearer.callbacks.tensor_board.
close_writer
(log_dir, logger)[source]¶ Decrement the reference count for a writer belonging to a specific log directory. If the reference count gets to zero, the writer will be closed and removed.
Parameters: - log_dir (str) – the log directory
- logger – the object releasing the writer
-
torchbearer.callbacks.tensor_board.
get_writer
(log_dir, logger, visdom=False, visdom_params=None)[source]¶ Get the writer assigned to the given log directory. If the writer doesn’t exist it will be created, and a reference to the logger added.
Parameters: - log_dir (str) – the log directory
- logger – the object requesting the writer. That object should call close_writer when its finished
- visdom (bool) – if true VisdomWriter is returned instead of tensorboard SummaryWriter
- visdom_params (VisdomParams) – Visdom parameter settings object, uses default if None
Returns: the SummaryWriter or VisdomWriter object
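The reference counting described for get_writer and close_writer can be illustrated with a small registry. This is a sketch under assumptions, not the module's actual code: `WriterRegistry`, `FakeWriter` and the use of `id(logger)` to track users are all hypothetical choices made for illustration.

```python
class FakeWriter:
    """Stand-in for a SummaryWriter; only tracks whether it was closed."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.closed = False

    def close(self):
        self.closed = True


class WriterRegistry:
    """Share one writer per log directory and close it when the last user releases it."""

    def __init__(self, factory):
        self._factory = factory
        self._writers = {}  # log_dir -> writer
        self._users = {}    # log_dir -> set of ids of objects holding the writer

    def get_writer(self, log_dir, logger):
        if log_dir not in self._writers:
            self._writers[log_dir] = self._factory(log_dir)
            self._users[log_dir] = set()
        self._users[log_dir].add(id(logger))
        return self._writers[log_dir]

    def close_writer(self, log_dir, logger):
        users = self._users.get(log_dir)
        if users is None:
            return
        users.discard(id(logger))
        if not users:
            # Last reference released: close and forget the writer
            self._writers.pop(log_dir).close()
            del self._users[log_dir]


registry = WriterRegistry(FakeWriter)
first_logger, second_logger = object(), object()
writer_a = registry.get_writer('./logs', first_logger)
writer_b = registry.get_writer('./logs', second_logger)
registry.close_writer('./logs', first_logger)
still_open = not writer_a.closed
registry.close_writer('./logs', second_logger)
```

Two callbacks that log to the same directory receive the same writer, and it is only closed once both have released it.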
-
class
torchbearer.callbacks.live_loss_plot.
LiveLossPlot
(on_batch=False, batch_step_size=10, on_epoch=True, draw_once=False, **kwargs)[source]¶ Callback to write metrics to LiveLossPlot, a library for visualisation in notebooks
Example:
>>> import torch.nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import LiveLossPlot
>>> # Example Trial which plots its metrics live using LiveLossPlot
>>> model = torch.nn.Linear(1,1)
>>> live_loss_plot = LiveLossPlot()
>>> trial = Trial(model, callbacks=[live_loss_plot], metrics=['acc'])
Parameters: - on_batch (bool) – If True, batch metrics will be logged. Else batch metrics will not be logged
- batch_step_size (int) – The number of batches between logging metrics
- on_epoch (bool) – If True, epoch metrics will be logged every epoch. Else epoch metrics will not be logged
- draw_once (bool) – If True, draw the plot only at the end of training. Else draw every time metrics are logged
- kwargs – Keyword arguments for livelossplot.PlotLosses
- State Requirements:
  - torchbearer.state.METRICS: Metrics should be a dict containing the metrics to be plotted
  - torchbearer.state.BATCH: Batch should be the current batch or iteration number in the epoch
Early Stopping¶
-
class
torchbearer.callbacks.early_stopping.
EarlyStopping
(monitor='val_loss', min_delta=0, patience=0, mode='auto', step_on_batch=False)[source]¶ Callback to stop training when a monitored quantity has stopped improving.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import EarlyStopping
>>> # Example Trial which does early stopping if the validation accuracy drops below the max seen for 5 epochs in a row
>>> stopping = EarlyStopping(monitor='val_acc', patience=5, mode='max')
>>> trial = Trial(None, callbacks=[stopping], metrics=['acc'])
Parameters: - monitor (str) – Name of quantity in metrics to be monitored
- min_delta (float) – Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.
- patience (int) – Number of epochs with no improvement after which training will be stopped.
- mode (str) – One of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity.
- State Requirements:
torchbearer.state.METRICS
: Metrics should be a dict containing the given monitor key as a minimum
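How min_delta, patience and mode interact can be sketched without any training loop. The class below is an illustration under assumptions, not torchbearer's implementation: `EarlyStoppingSketch` and its `step` method are hypothetical, and the auto mode is left out.

```python
class EarlyStoppingSketch:
    """Track a monitored value and report when training should stop."""

    def __init__(self, min_delta=0.0, patience=0, mode='min'):
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.best = float('inf') if mode == 'min' else float('-inf')
        self.wait = 0

    def step(self, value):
        """Return True if training should stop after observing value."""
        if self.mode == 'min':
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.wait = 0
            return False
        # No improvement: count the step and stop once patience is used up
        self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=2, mode='max')
val_acc_history = [0.5, 0.6, 0.59, 0.6, 0.58]
stopped_at = None
for epoch, val_acc in enumerate(val_acc_history):
    if stopper.step(val_acc):
        stopped_at = epoch
        break
```

With a patience of 2, the run stops at the second consecutive epoch that fails to beat the best value seen so far.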
-
load_state_dict
(state_dict)[source]¶ Resume this callback from the given state. Expects that this callback was constructed in the same way.
Parameters: state_dict (dict) – The state dict to reload Returns: self Return type: Callback
-
on_end_epoch
(state)¶
-
on_step_training
(state)¶
-
class
torchbearer.callbacks.terminate_on_nan.
TerminateOnNaN
(monitor='running_loss')[source]¶ Callback which monitors the given metric and halts training if its value is NaN or inf.
Example:
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import TerminateOnNaN
>>> # Example Trial which terminates on a NaN, forced by a separate callback. Terminates on the 11th batch since
>>> # the running loss only updates every 10 iterations.
>>> term = TerminateOnNaN(monitor='running_loss')
>>> @torchbearer.callbacks.on_criterion
... def force_terminate(state):
...     if state[torchbearer.BATCH] == 5:
...         state[torchbearer.LOSS] = state[torchbearer.LOSS] * torch.Tensor([float('NaN')])
>>> trial = Trial(None, callbacks=[term, force_terminate], metrics=['loss'], verbose=2).for_steps(30).run(1)
Invalid running_loss, terminating
Parameters: monitor (str) – The name of the metric to monitor
- State Requirements:
  - torchbearer.state.METRICS: Metrics should be a dict containing at least the key monitor
-
on_end_epoch
(state)[source]¶ Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict) – The current state dict of the Trial
.
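The check that TerminateOnNaN is described as performing amounts to testing the monitored value for NaN or infinity. A minimal sketch, assuming the metric can be converted to a float (`is_invalid_metric` is a hypothetical helper, not part of torchbearer):

```python
import math


def is_invalid_metric(metrics, monitor='running_loss'):
    """Return True if the monitored metric is NaN or infinite."""
    value = float(metrics[monitor])
    return math.isnan(value) or math.isinf(value)


healthy = is_invalid_metric({'running_loss': 0.25})
diverged = is_invalid_metric({'running_loss': float('nan')})
overflowed = is_invalid_metric({'running_loss': float('inf')})
```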
Gradient Clipping¶
-
class
torchbearer.callbacks.gradient_clipping.
GradientClipping
(clip_value, params=None)[source]¶ GradientClipping callback, which uses ‘torch.nn.utils.clip_grad_value_’ to clip the gradients of the given parameters to the given value. If params is None they will be retrieved from state.
Example:
>>> import torch.nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import GradientClipping
>>> # Example Trial which clips all model gradients to the range [-2, 2]
>>> model = torch.nn.Linear(1,1)
>>> clip = GradientClipping(2)
>>> trial = Trial(model, callbacks=[clip], metrics=['acc'])
Parameters: - clip_value (float or int) – maximum allowed value of the gradients. The gradients are clipped in the range [-clip_value, clip_value]
- params (Iterable[Tensor] or Tensor, optional) – an iterable of Tensors or a single Tensor that will have gradients clipped, otherwise this is retrieved from state
- State Requirements:
torchbearer.state.MODEL
: Model should have the parameters method
-
class
torchbearer.callbacks.gradient_clipping.
GradientNormClipping
(max_norm, norm_type=2, params=None)[source]¶ GradientNormClipping callback, which uses ‘torch.nn.utils.clip_grad_norm_’ to clip the gradient norms to the given value. If params is None they will be retrieved from state.
Example:
>>> import torch.nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import GradientNormClipping
>>> # Example Trial which clips all model gradient norms at 2 under the L1 norm.
>>> model = torch.nn.Linear(1,1)
>>> clip = GradientNormClipping(2, 1)
>>> trial = Trial(model, callbacks=[clip], metrics=['acc'])
Parameters: - max_norm (float or int) – max norm of the gradients
- norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.
- params (Iterable[Tensor] or Tensor, optional) – an iterable of Tensors or a single Tensor that will have gradients normalized, otherwise this is retrieved from state
- State Requirements:
torchbearer.state.MODEL
: Model should have the parameters method
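The difference between the two callbacks can be seen on a flat list of gradient values. The functions below are a plain-Python sketch of value clipping and norm clipping, written for illustration only; `clip_by_value` and `clip_by_norm` are hypothetical names, and the small epsilon in the scale factor is an assumption modelled on common practice.

```python
def clip_by_value(grads, clip_value):
    """Clamp every gradient entry into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads, max_norm, norm_type=2):
    """Rescale all gradients together so that their p-norm is at most max_norm."""
    if norm_type == float('inf'):
        total_norm = max(abs(g) for g in grads)
    else:
        total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1:
        return [g * scale for g in grads]
    return list(grads)


grads = [3.0, -4.0]
by_value = clip_by_value(grads, 2)     # each entry is clamped independently
by_norm = clip_by_norm(grads, 2.5, 2)  # direction is kept, length is reduced
```

Value clipping changes the direction of the gradient vector, whereas norm clipping only shortens it.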
Learning Rate Schedulers¶
-
class
torchbearer.callbacks.torch_scheduler.
CosineAnnealingLR
(T_max, eta_min=0, last_epoch=-1, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import CosineAnnealingLR
>>> # Example scheduler which uses cosine learning rate annealing - see PyTorch docs
>>> scheduler = CosineAnnealingLR(T_max=10)
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
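For intuition about T_max and eta_min, the closed form of cosine annealing (without restarts) can be evaluated directly. This is a sketch of the schedule's shape, not the scheduler's code; `cosine_annealing_lr` and `base_lr` are hypothetical names.

```python
import math


def cosine_annealing_lr(base_lr, step, T_max, eta_min=0.0):
    """Learning rate after `step` steps of a cosine schedule of length T_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / T_max)) / 2


start_lr = cosine_annealing_lr(0.1, 0, T_max=10)
middle_lr = cosine_annealing_lr(0.1, 5, T_max=10)
end_lr = cosine_annealing_lr(0.1, 10, T_max=10)
```

The rate starts at the optimizer's learning rate, passes the midpoint halfway through, and reaches eta_min after T_max steps.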
-
class
torchbearer.callbacks.torch_scheduler.
CyclicLR
(base_lr, max_lr, monitor='val_loss', step_size_up=2000, step_size_down=None, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import CyclicLR
>>> # Example scheduler which cycles the learning rate between 0.01 and 0.1
>>> scheduler = CyclicLR(0.01, 0.1)
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).for_val_steps(10).run(1)
Parameters: - monitor (str) – The name of the quantity in metrics to monitor. (Default value = ‘val_loss’)
- step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
-
class
torchbearer.callbacks.torch_scheduler.
ExponentialLR
(gamma, last_epoch=-1, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import ExponentialLR
>>> # Example scheduler which multiplies the learning rate by 0.1 every epoch
>>> scheduler = ExponentialLR(gamma=0.1)
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
-
class
torchbearer.callbacks.torch_scheduler.
LambdaLR
(lr_lambda, last_epoch=-1, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import LambdaLR
>>> # Example Trial which performs the two learning rate lambdas from the PyTorch docs
>>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95 ** epoch
>>> scheduler = LambdaLR(lr_lambda=[lambda1, lambda2])
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
See: PyTorch LambdaLR
-
class
torchbearer.callbacks.torch_scheduler.
MultiStepLR
(milestones, gamma=0.1, last_epoch=-1, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import MultiStepLR
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
>>> scheduler = MultiStepLR(milestones=[30,80], gamma=0.1)
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
See: PyTorch MultiStepLR
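The milestone comments in the MultiStepLR example follow from a simple rule: the rate is multiplied by gamma once for every milestone already reached. A sketch of that arithmetic, with `multi_step_lr` as a hypothetical helper rather than the scheduler itself:

```python
from bisect import bisect_right


def multi_step_lr(base_lr, epoch, milestones, gamma=0.1):
    """Learning rate at `epoch` when decaying by gamma at each milestone."""
    milestones_passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** milestones_passed


lr_before = multi_step_lr(0.05, 29, [30, 80])
lr_between = multi_step_lr(0.05, 30, [30, 80])
lr_after = multi_step_lr(0.05, 80, [30, 80])
```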
-
class
torchbearer.callbacks.torch_scheduler.
ReduceLROnPlateau
(monitor='val_loss', mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import ReduceLROnPlateau
>>> # Example scheduler which divides the learning rate by 10 on plateaus of 5 epochs without significant
>>> # validation loss decrease, in order to stop overshooting the local minima. new_lr = lr * factor
>>> scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5)
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).for_val_steps(10).run(1)
Parameters: - monitor (str) – The name of the quantity in metrics to monitor. (Default value = ‘val_loss’)
- step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
-
class
torchbearer.callbacks.torch_scheduler.
StepLR
(step_size, gamma=0.1, last_epoch=-1, step_on_batch=False)[source]¶ Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import StepLR
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 60
>>> # lr = 0.0005 if 60 <= epoch < 90
>>> scheduler = StepLR(step_size=30, gamma=0.1)
>>> trial = Trial(None, callbacks=[scheduler], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: step_on_batch (bool) – If True, step will be called on each training iteration rather than on each epoch
See: PyTorch StepLR
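The StepLR comments above correspond to multiplying the rate by gamma once every step_size epochs. A sketch of the arithmetic only, where `step_lr` is a hypothetical helper and not the scheduler's code:

```python
def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """Learning rate at `epoch` when decaying by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


lr_epoch_29 = step_lr(0.05, 29, step_size=30)
lr_epoch_30 = step_lr(0.05, 30, step_size=30)
lr_epoch_60 = step_lr(0.05, 60, step_size=30)
```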
-
class
torchbearer.callbacks.torch_scheduler.
TorchScheduler
(scheduler_builder, monitor=None, step_on_batch=False)[source]¶ -
on_end_epoch
(state)[source]¶ Perform some action with the given state as context at the end of each epoch.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_sample
(state)[source]¶ Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict) – The current state dict of the Trial
.
-
on_start
(state)[source]¶ Perform some action with the given state as context at the start of a model fit.
Parameters: state (dict) – The current state dict of the Trial
.
-
Learning Rate Finders¶
Weight Decay¶
-
class
torchbearer.callbacks.weight_decay.
L1WeightDecay
(rate=0.0005, params=None)[source]¶ WeightDecay callback which uses an L1 norm with the given rate and parameters. If params is None (default) then the parameters will be retrieved from the model.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import L1WeightDecay
>>> # Example Trial which runs a trial with weight decay on the model using an L1 norm
>>> decay = L1WeightDecay()
>>> trial = Trial(None, callbacks=[decay], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: - rate (float) – The decay rate or lambda
- params (Iterable[Tensor] or Tensor, optional) – an iterable of Tensors or a single Tensor that will be decayed, otherwise this is retrieved from state
- State Requirements:
  - torchbearer.state.MODEL: Model should have the parameters method
  - torchbearer.state.LOSS: Loss should be a tensor that can be incremented
-
class
torchbearer.callbacks.weight_decay.
L2WeightDecay
(rate=0.0005, params=None)[source]¶ WeightDecay callback which uses an L2 norm with the given rate and parameters. If params is None (default) then the parameters will be retrieved from the model.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import L2WeightDecay
>>> # Example Trial which runs a trial with weight decay on the model using an L2 norm
>>> decay = L2WeightDecay()
>>> trial = Trial(None, callbacks=[decay], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: - rate (float) – The decay rate or lambda
- params (Iterable[Tensor] or Tensor, optional) – an iterable of Tensors or a single Tensor that will be decayed, otherwise this is retrieved from state
- State Requirements:
  - torchbearer.state.MODEL: Model should have the parameters method
  - torchbearer.state.LOSS: Loss should be a tensor that can be incremented
-
class
torchbearer.callbacks.weight_decay.
WeightDecay
(rate=0.0005, p=2, params=None)[source]¶ Create a WeightDecay callback which uses the given norm on the given parameters and with the given decay rate. If params is None (default) then the parameters will be retrieved from the model.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import WeightDecay
>>> # Example Trial which runs a trial with weight decay on the model
>>> decay = WeightDecay()
>>> trial = Trial(None, callbacks=[decay], metrics=['loss'], verbose=2).for_steps(10).run(1)
Parameters: - rate (float) – The decay rate or lambda
- p (int) – The norm level
- params (Iterable[Tensor] or Tensor, optional) – an iterable of Tensors or a single Tensor that will be decayed, otherwise this is retrieved from state
- State Requirements:
  - torchbearer.state.MODEL: Model should have the parameters method
  - torchbearer.state.LOSS: Loss should be a tensor that can be incremented
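One plausible reading of the description is that the loss is incremented by rate times the sum of the p-norms of the parameters. The sketch below works on plain lists of floats and is an assumption about the form of the penalty, not torchbearer's code; `p_norm` and `decayed_loss` are hypothetical names.

```python
def p_norm(values, p):
    """p-norm of a flat list of floats."""
    return sum(abs(v) ** p for v in values) ** (1.0 / p)


def decayed_loss(loss, params, rate=0.0005, p=2):
    """Add rate * sum of per-parameter p-norms to the loss (assumed form)."""
    return loss + rate * sum(p_norm(param, p) for param in params)


params = [[3.0, -4.0]]
l2_loss = decayed_loss(1.0, params, rate=0.1, p=2)  # L2 norm of [3, -4] is 5
l1_loss = decayed_loss(1.0, params, rate=0.1, p=1)  # L1 norm of [3, -4] is 7
```

Under this reading, L1WeightDecay and L2WeightDecay correspond to p=1 and p=2 respectively.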
Weight / Bias Initialisation¶
-
class
torchbearer.callbacks.init.
KaimingNormal
(a=0, mode='fan_in', nonlinearity='leaky_relu', modules=None, targets=['Conv', 'Linear', 'Bilinear'])[source]¶ Kaiming Normal weight initialisation. Uses torch.nn.init.kaiming_normal_ on the weight attribute of the filtered modules.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks.init import KaimingNormal
>>> # 100 random data points
>>> data = torch.rand(100, 3, 5, 5)
>>> example_batch = data[:3]
>>> initialiser = KaimingNormal()
>>> # Model and trial using kaiming init for some random data
>>> model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU())
>>> trial = Trial(model, callbacks=[initialiser]).with_train_data(data, data+5)
@inproceedings{he2015delving, title={Delving deep into rectifiers: Surpassing human-level performance on imagenet classification}, author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, booktitle={Proceedings of the IEEE international conference on computer vision}, pages={1026--1034}, year={2015} }
Parameters: - a (int) – See PyTorch kaiming_normal_
- mode (str) – See PyTorch kaiming_normal_
- nonlinearity (str) – See PyTorch kaiming_normal_
- modules (Iterable[nn.Module] or nn.Module, optional) – an iterable of nn.Modules or a single nn.Module that will have weights initialised, otherwise this is retrieved from the model
- targets (list[String]) – A list of lookup strings to match which modules will be initialised
-
class
torchbearer.callbacks.init.
KaimingUniform
(a=0, mode='fan_in', nonlinearity='leaky_relu', modules=None, targets=['Conv', 'Linear', 'Bilinear'])[source]¶ Kaiming Uniform weight initialisation. Uses torch.nn.init.kaiming_uniform_ on the weight attribute of the filtered modules.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks.init import KaimingUniform
>>> # 100 random data points
>>> data = torch.rand(100, 3, 5, 5)
>>> example_batch = data[:3]
>>> initialiser = KaimingUniform()
>>> # Model and trial using kaiming init for some random data
>>> model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU())
>>> trial = Trial(model, callbacks=[initialiser]).with_train_data(data, data+5)
@inproceedings{he2015delving, title={Delving deep into rectifiers: Surpassing human-level performance on imagenet classification}, author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, booktitle={Proceedings of the IEEE international conference on computer vision}, pages={1026--1034}, year={2015} }
Parameters: - a (int) – See PyTorch kaiming_uniform_
- mode (str) – See PyTorch kaiming_uniform_
- nonlinearity (str) – See PyTorch kaiming_uniform_
- modules (Iterable[nn.Module] or nn.Module, optional) – an iterable of nn.Modules or a single nn.Module that will have weights initialised, otherwise this is retrieved from the model
- targets (list[String]) – A list of lookup strings to match which modules will be initialised
-
class
torchbearer.callbacks.init.
LsuvInit
(data_item, weight_lambda=None, needed_std=1.0, std_tol=0.1, max_attempts=10, do_orthonorm=True)[source]¶ Layer-sequential unit-variance (LSUV) initialization as described in All you need is a good init and modified from the code by ducha-aiki. To be consistent with the paper, LsuvInit should be preceded by a ZeroBias init on the Linear and Conv layers.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks.init import LsuvInit
>>> # 100 random data points
>>> data = torch.rand(100, 3, 5, 5)
>>> example_batch = data[:3]
>>> lsuv = LsuvInit(example_batch)
>>> # Model and trial using lsuv init for some random data
>>> model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU())
>>> trial = Trial(model, callbacks=[lsuv]).with_train_data(data, data+5)
@article{mishkin2015all, title={All you need is a good init}, author={Mishkin, Dmytro and Matas, Jiri}, journal={arXiv preprint arXiv:1511.06422}, year={2015} }
Parameters: - data_item (torch.Tensor) – A representative data item to put through the model
- weight_lambda (lambda) – A function that takes a module and returns the weight attribute. If None, defaults to module.weight.
- needed_std – See paper, where needed_std is always 1.0
- std_tol – See paper, Tol_{var}
- max_attempts – See paper, T_{max}
- do_orthonorm – See paper, first pre-initialise with orthonormal matrices
- State Requirements:
torchbearer.state.MODEL
: Model should have the modules method if modules is None
-
class
torchbearer.callbacks.init.
WeightInit
(initialiser=<function WeightInit.<lambda>>, modules=None, targets=['Conv', 'Linear', 'Bilinear'])[source]¶ Base class for weight initialisations. Performs the provided function for each module when on_init is called.
Parameters: - initialiser (lambda) – a function which initialises an nn.Module inplace
- modules (Iterable[nn.Module] or nn.Module, optional) – an iterable of nn.Modules or a single nn.Module that will have weights initialised, otherwise this is retrieved from the model
- targets (list[String]) – A list of lookup strings to match which modules will be initialised
- State Requirements:
torchbearer.state.MODEL
: Model should have the modules method if modules is None
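The targets list acts as a filter over the model's modules. The sketch below assumes the lookup strings are matched as substrings of each module's class name, which is an assumption about the matching rule rather than something the text above confirms; the dummy classes and `filter_modules` are hypothetical stand-ins so that the example runs without PyTorch.

```python
class Conv2d:
    """Dummy stand-in for a convolutional layer."""


class Linear:
    """Dummy stand-in for a fully connected layer."""


class ReLU:
    """Dummy stand-in for an activation with no weights."""


def filter_modules(modules, targets=('Conv', 'Linear', 'Bilinear')):
    """Keep modules whose class name contains any of the lookup strings."""
    return [m for m in modules if any(t in type(m).__name__ for t in targets)]


def apply_initialiser(modules, initialiser, targets=('Conv', 'Linear', 'Bilinear')):
    """Call the in-place initialiser on every module that passes the filter."""
    selected = filter_modules(modules, targets)
    for module in selected:
        initialiser(module)
    return selected


initialised = []
model_modules = [Conv2d(), ReLU(), Linear()]
apply_initialiser(model_modules, lambda m: initialised.append(type(m).__name__))
```

With the default targets, the activation is skipped and only the convolutional and linear stand-ins are passed to the initialiser.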
-
class
torchbearer.callbacks.init.
XavierNormal
(gain=1, modules=None, targets=['Conv', 'Linear', 'Bilinear'])[source]¶ Xavier Normal weight initialisation. Uses torch.nn.init.xavier_normal_ on the weight attribute of the filtered modules.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks.init import XavierNormal
>>> # 100 random data points
>>> data = torch.rand(100, 3, 5, 5)
>>> example_batch = data[:3]
>>> initialiser = XavierNormal()
>>> # Model and trial using Xavier init for some random data
>>> model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU())
>>> trial = Trial(model, callbacks=[initialiser]).with_train_data(data, data+5)
@inproceedings{glorot2010understanding, title={Understanding the difficulty of training deep feedforward neural networks}, author={Glorot, Xavier and Bengio, Yoshua}, booktitle={Proceedings of the thirteenth international conference on artificial intelligence and statistics}, pages={249--256}, year={2010} }
Parameters: - gain (int) – See PyTorch xavier_normal_
- modules (Iterable[nn.Module] or nn.Module, optional) – an iterable of nn.Modules or a single nn.Module that will have weights initialised, otherwise this is retrieved from the model
- targets (list[String]) – A list of lookup strings to match which modules will be initialised
-
class
torchbearer.callbacks.init.
XavierUniform
(gain=1, modules=None, targets=['Conv', 'Linear', 'Bilinear'])[source]¶ Xavier Uniform weight initialisation. Uses torch.nn.init.xavier_uniform_ on the weight attribute of the filtered modules.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks.init import XavierUniform
>>> # 100 random data points
>>> data = torch.rand(100, 3, 5, 5)
>>> example_batch = data[:3]
>>> initialiser = XavierUniform()
>>> # Model and trial using Xavier init for some random data
>>> model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU())
>>> trial = Trial(model, callbacks=[initialiser]).with_train_data(data, data+5)
@inproceedings{glorot2010understanding, title={Understanding the difficulty of training deep feedforward neural networks}, author={Glorot, Xavier and Bengio, Yoshua}, booktitle={Proceedings of the thirteenth international conference on artificial intelligence and statistics}, pages={249--256}, year={2010} }
Parameters: - gain (int) – See PyTorch xavier_uniform_
- modules (Iterable[nn.Module] or nn.Module, optional) – an iterable of nn.Modules or a single nn.Module that will have weights initialised, otherwise this is retrieved from the model
- targets (list[String]) – A list of lookup strings to match which modules will be initialised
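For the uniform variant, torch.nn.init.xavier_uniform_ samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). The plain-Python sketch below only reproduces that bound and draws a few values from it; xavier_uniform_bound is a hypothetical helper name, not a torchbearer or PyTorch function.

```python
import math
import random


def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width of the uniform range used by Xavier (Glorot) uniform init."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


# For nn.Conv2d(3, 1, 3): fan_in = 27 and fan_out = 9
bound = xavier_uniform_bound(fan_in=27, fan_out=9)

# Draw one value per weight of the 3x3x3 kernel (seeded for repeatability)
rng = random.Random(0)
weights = [rng.uniform(-bound, bound) for _ in range(27)]
```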
-
class
torchbearer.callbacks.init.
ZeroBias
(modules=None, targets=['Conv', 'Linear', 'Bilinear'])[source]¶ Zero initialisation for the bias attributes of filtered modules. This is recommended for use in conjunction with weight initialisation schemes.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks.init import ZeroBias
>>> # 100 random data points
>>> data = torch.rand(100, 3, 5, 5)
>>> example_batch = data[:3]
>>> initialiser = ZeroBias()
>>> # Model and trial using zero bias init for some random data
>>> model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU())
>>> trial = Trial(model, callbacks=[initialiser]).with_train_data(data, data+5)
Parameters: - modules (Iterable[nn.Module] or nn.Module, optional) – an iterable of nn.Modules or a single nn.Module that will have biases initialised, otherwise this is retrieved from the model
- targets (list[String]) – A list of lookup strings to match which modules will be initialised
Regularisers¶
-
class
torchbearer.callbacks.cutout.
Cutout
(n_holes, length, constant=0.0, seed=None)[source]¶ Cutout callback which randomly masks out patches of image data. The implementation is a modified version of the code found here.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import Cutout
>>> # Example Trial which does Cutout regularisation
>>> cutout = Cutout(1, 10)
>>> trial = Trial(None, callbacks=[cutout], metrics=['acc'])
@article{devries2017improved, title={Improved regularization of convolutional neural networks with Cutout}, author={DeVries, Terrance and Taylor, Graham W}, journal={arXiv preprint arXiv:1708.04552}, year={2017} }
Parameters: - n_holes (int) – Number of patches to cut out of each image.
- length (int) – The length (in pixels) of each square patch.
- constant (float) – Constant value for each square patch
- seed – Random seed
State Requirements: - torchbearer.state.X: State should have the current data stored
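To make the roles of n_holes, length and constant concrete, here is a minimal sketch of the masking step on a single-channel image stored as nested lists. It assumes the patch centre is drawn uniformly and the patch is clipped at the image border, as in the Cutout paper; the function name cutout and its rng argument are hypothetical, and the callback itself operates on batched tensors in state.

```python
import random


def cutout(image, n_holes, length, constant=0.0, rng=None):
    """Return a copy of a 2D image (list of rows) with square patches overwritten."""
    rng = rng or random.Random()
    height, width = len(image), len(image[0])
    out = [row[:] for row in image]  # leave the input untouched
    for _ in range(n_holes):
        # Pick the patch centre, then clip the patch to the image bounds
        cy, cx = rng.randrange(height), rng.randrange(width)
        y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, height)
        x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, width)
        for y in range(y1, y2):
            for x in range(x1, x2):
                out[y][x] = constant
    return out


image = [[1.0] * 8 for _ in range(8)]
masked = cutout(image, n_holes=1, length=4, rng=random.Random(0))
```

Because of the clipping, a patch whose centre lies near the border covers fewer than length * length pixels.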
-
class
torchbearer.callbacks.cutout.
RandomErase
(n_holes, length, seed=None)[source]¶ Random erase callback which replaces random patches of image data with random noise. The implementation is a modified version of the cutout code found here.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import RandomErase
>>> # Example Trial which does random erase regularisation
>>> erase = RandomErase(1, 10)
>>> trial = Trial(None, callbacks=[erase], metrics=['acc'])
@article{zhong2017random, title={Random erasing data augmentation}, author={Zhong, Zhun and Zheng, Liang and Kang, Guoliang and Li, Shaozi and Yang, Yi}, journal={arXiv preprint arXiv:1708.04896}, year={2017} }
Parameters: - n_holes (int) – Number of patches to cut out of each image.
- length (int) – The length (in pixels) of each square patch.
- seed – Random seed
State Requirements: - torchbearer.state.X: State should have the current data stored
-
class
torchbearer.callbacks.cutout.
CutMix
(alpha, classes=-1, seed=None)[source]¶ CutMix callback which replaces a random patch of image data with the corresponding patch from another image. This callback also converts labels to one hot before combining them according to the lambda parameters, sampled from a beta distribution as is done in the paper.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import CutMix
>>> # Example Trial which does CutMix regularisation
>>> cutmix = CutMix(1, classes=10)
>>> trial = Trial(None, callbacks=[cutmix], metrics=['acc'])
@article{yun2019cutmix, title={Cutmix: Regularization strategy to train strong classifiers with localizable features}, author={Yun, Sangdoo and Han, Dongyoon and Oh, Seong Joon and Chun, Sanghyuk and Choe, Junsuk and Yoo, Youngjoon}, journal={arXiv preprint arXiv:1905.04899}, year={2019} }
Parameters: - alpha (float) – The alpha value for the beta distribution.
- classes (int) – The number of classes for conversion to one hot.
- seed – Random seed
State Requirements: - torchbearer.state.X: State should have the current data stored
- torchbearer.state.Y_TRUE: State should have the current targets stored
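The label handling described above can be sketched in a few lines: targets are converted to one hot and then blended with a lambda drawn from Beta(alpha, alpha). The helper names one_hot and mix_labels are hypothetical, and the sketch leaves out the image patch itself (in the paper, lambda also determines the patch area).

```python
import random


def one_hot(label, classes):
    """Encode an integer class label as a one-hot list."""
    return [1.0 if i == label else 0.0 for i in range(classes)]


def mix_labels(label_a, label_b, lam, classes):
    """Blend two one-hot targets with weights lam and (1 - lam)."""
    a, b = one_hot(label_a, classes), one_hot(label_b, classes)
    return [lam * ya + (1.0 - lam) * yb for ya, yb in zip(a, b)]


rng = random.Random(0)
lam = rng.betavariate(1.0, 1.0)  # alpha = 1 gives a uniform lambda in [0, 1]
mixed = mix_labels(2, 5, lam, classes=10)
```

The mixed target still sums to one, which is why a loss that accepts soft targets is needed downstream.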
-
class
torchbearer.callbacks.mixup.
Mixup
(alpha=1.0, lam=-10.0)[source]¶ Perform mixup on the model inputs. Requires use of Mixup.loss() as the criterion; if a different criterion is used, the lambdas can be found in state under MIXUP_LAMBDA. Model targets will be a tuple containing the original target and permuted target.
Note
The accuracy metric for mixup is different on training to deal with the different targets, but for validation it is exactly the categorical accuracy, despite being called “val_mixup_acc”.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import Mixup
>>> # Example Trial which does Mixup regularisation
>>> mixup = Mixup(0.9)
>>> trial = Trial(None, criterion=Mixup.loss, callbacks=[mixup], metrics=['acc'])
@inproceedings{zhang2018mixup, title={mixup: Beyond Empirical Risk Minimization}, author={Hongyi Zhang and Moustapha Cisse and Yann N. Dauphin and David Lopez-Paz}, booktitle={International Conference on Learning Representations}, year={2018} }
Parameters: alpha (float) – The alpha value to use in the beta distribution. -
RANDOM
= -10.0¶
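The arithmetic behind mixup, as described in the paper, is a convex combination of two inputs and the same combination of their losses. The sketch below is an assumption-laden illustration in plain Python: mixup_inputs, mixup_loss and squared_error are hypothetical names, and the real callback works on batched tensors with a permuted copy of the batch.

```python
import random


def mixup_inputs(x_a, x_b, lam):
    """Blend two inputs element-wise with weights lam and (1 - lam)."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]


def mixup_loss(loss_fn, prediction, targets, lam):
    """Weight the loss against the original and the permuted target."""
    y_a, y_b = targets  # (original target, permuted target)
    return lam * loss_fn(prediction, y_a) + (1.0 - lam) * loss_fn(prediction, y_b)


def squared_error(prediction, target):
    return (prediction - target) ** 2


lam = random.Random(0).betavariate(0.9, 0.9)  # alpha = 0.9 as in the example above
x = mixup_inputs([1.0, 2.0], [3.0, 4.0], lam)
loss = mixup_loss(squared_error, 0.5, (0.0, 1.0), lam)
```

This is also why the targets become a tuple: a loss that expects a single target cannot be applied directly.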
-
-
class
torchbearer.callbacks.sample_pairing.
SamplePairing
(policy=None)[source]¶ Perform SamplePairing on the model inputs. This is the process of averaging each image with another random image without changing the targets. The key here is to use the policy function to only do this some of the time.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import SamplePairing
>>> # Example Trial which does Sample Pairing regularisation with the policy from the paper
>>> pairing = SamplePairing()
>>> trial = Trial(None, callbacks=[pairing], metrics=['acc'])
@article{inoue2018data, title={Data augmentation by pairing samples for images classification}, author={Inoue, Hiroshi}, journal={arXiv preprint arXiv:1801.02929}, year={2018} }
Parameters: policy – A function of state which returns True if the current batch should be paired. -
static
default_policy
(start_epoch, end_epoch, on_epochs, off_epochs)[source]¶ Return a policy which performs sample pairing according to the process defined in the paper.
Parameters: - start_epoch (int) – Epoch to start pairing on
- end_epoch (int) – Epoch to end pairing on (and begin fine-tuning)
- on_epochs (int) – Number of epochs to run sample pairing for before a break
- off_epochs (int) – Number of epochs to break for
Returns: A policy function
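One plausible reading of these four parameters is an on/off cycle that runs between start_epoch and end_epoch. The sketch below encodes that reading only; the exact boundary handling in torchbearer is an assumption, and sketch_policy takes the epoch number directly, whereas a real policy is a function of state.

```python
def sketch_policy(start_epoch, end_epoch, on_epochs, off_epochs):
    """Hypothetical stand-in for a pairing policy built from an on/off cycle."""
    def policy(epoch):
        if not (start_epoch <= epoch < end_epoch):
            return False  # before the warm-up ends or during fine-tuning
        # Position within the repeating on/off cycle
        return (epoch - start_epoch) % (on_epochs + off_epochs) < on_epochs
    return policy


policy = sketch_policy(start_epoch=1, end_epoch=10, on_epochs=2, off_epochs=1)
schedule = [policy(epoch) for epoch in range(12)]
```

With these values, pairing is active for two epochs, paused for one, and switched off entirely from epoch 10 onwards.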
-
class
torchbearer.callbacks.label_smoothing.
LabelSmoothingRegularisation
(epsilon, classes=-1)[source]¶ Perform Label Smoothing Regularisation (LSR) on the targets during training. This involves converting the target to a one-hot vector and smoothing according to the value epsilon.
Note
Requires a multi-label loss, such as nn.BCELoss
Example:
>>> import torch.nn as nn
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import LabelSmoothingRegularisation
>>> # Example Trial which does label smoothing regularisation
>>> smoothing = LabelSmoothingRegularisation(0.1)
>>> trial = Trial(None, criterion=nn.BCELoss(), callbacks=[smoothing], metrics=['acc'])
@article{szegedy2015rethinking, title={Rethinking the inception architecture for computer vision}, author={Szegedy, Christian and Vanhoucke, Vincent and Ioffe, Sergey and Shlens, Jonathon and Wojna, Zbigniew}, journal={arXiv preprint arXiv:1512.00567}, year={2015} }
Parameters: - epsilon (float) – The epsilon parameter from the paper
- classes (int) – The number of target classes, not required if the target is already one-hot encoded
-
on_sample
(state)[source]¶ Perform some action with the given state as context after data has been sampled from the generator.
Parameters: state (dict) – The current state dict of the Trial
.
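For reference, the smoothing rule from the paper replaces a one-hot target y with (1 - epsilon) * y + epsilon / K, where K is the number of classes. The sketch below applies that rule to a single label; smooth_labels is a hypothetical helper, and whether the callback uses exactly this form is an assumption.

```python
def smooth_labels(label, classes, epsilon):
    """One-hot encode `label`, then move `epsilon` of the mass to a uniform prior."""
    one_hot = [1.0 if i == label else 0.0 for i in range(classes)]
    return [(1.0 - epsilon) * y + epsilon / classes for y in one_hot]


smoothed = smooth_labels(label=1, classes=4, epsilon=0.1)
```

The smoothed target still sums to one, but no entry is exactly zero or one.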
Unpack State¶
-
torchbearer.callbacks.
unpack_state
¶ alias of
torchbearer.callbacks.unpack_state
Decorators¶
Main¶
The main callback decorators simply take a function and bind it to a callback point, returning the result.
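The binding idea can be sketched without torchbearer at all. In the following illustration, SketchCallback and bind_to are hypothetical names: a fresh callback object is created and the decorated function is attached to one of its hooks, while the other hooks remain no-ops. This is not the library's implementation, only the general pattern.

```python
class SketchCallback:
    """Hypothetical stand-in for a callback base class with no-op hooks."""

    def on_start(self, state):
        pass

    def on_end(self, state):
        pass


def bind_to(hook_name):
    """Build a decorator that binds a function to the named callback point."""
    def decorator(func):
        callback = SketchCallback()
        # The instance attribute shadows the no-op method of the same name
        setattr(callback, hook_name, func)
        return callback
    return decorator


on_start = bind_to('on_start')

calls = []


@on_start
def record_start(state):
    calls.append(state['message'])


record_start.on_start({'message': 'Starting training.'})
record_start.on_end({})  # unbound hooks stay as no-ops
```

After decoration, record_start is no longer a function but a callback object, which is why it can be passed straight into a list of callbacks.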
-
torchbearer.callbacks.decorators.
on_init
(func)[source]¶ The on_init() decorator is used to initialise a Callback with on_init() calling the decorated function.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_init
>>> # Example callback on init
>>> @on_init
... def print_callback(state):
...     print('Initialised trial.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Initialised trial.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_init() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_start
(func)[source]¶ The on_start() decorator is used to initialise a Callback with on_start() calling the decorated function.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_start
>>> # Example callback on start
>>> @on_start
... def print_callback(state):
...     print('Starting training.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Starting training.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_start() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_start_epoch
(func)[source]¶ The on_start_epoch() decorator is used to initialise a Callback with on_start_epoch() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_start_epoch
>>> # Example callback running at start of each epoch
>>> @on_start_epoch
... def print_callback(state):
...     print('Starting epoch {}.'.format(state[torchbearer.EPOCH]))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Starting epoch 0.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_start_epoch() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_start_training
(func)[source]¶ The on_start_training() decorator is used to initialise a Callback with on_start_training() calling the decorated function.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_start_training
>>> # Example callback running at start of the training pass
>>> @on_start_training
... def print_callback(state):
...     print('Starting training.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Starting training.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_start_training() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_sample
(func)[source]¶ The on_sample() decorator is used to initialise a Callback with on_sample() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_sample
>>> # Example callback running each time a sample is taken from the dataset
>>> @on_sample
... def print_callback(state):
...     print('Current sample {}.'.format(state[torchbearer.X]))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Current sample None.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_sample() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_forward
(func)[source]¶ The on_forward() decorator is used to initialise a Callback with on_forward() calling the decorated function.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_forward
>>> # Example callback running after each training forward pass of the torch model
>>> @on_forward
... def print_callback(state):
...     print('Evaluated training batch.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Evaluated training batch.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_forward() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_criterion
(func)[source]¶ The on_criterion() decorator is used to initialise a Callback with on_criterion() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_criterion
>>> # Example callback running after each evaluation of the loss
>>> @on_criterion
... def print_callback(state):
...     print('Current loss {}.'.format(state[torchbearer.LOSS].item()))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Current loss 0.0.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_criterion() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_backward
(func)[source]¶ The on_backward() decorator is used to initialise a Callback with on_backward() calling the decorated function.
Example:
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_backward
>>> # Example callback running after each backward pass of the torch model
>>> @on_backward
... def print_callback(state):
...     print('Doing backward.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Doing backward.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_backward() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_step_training
(func)[source]¶ The on_step_training() decorator is used to initialise a Callback with on_step_training() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_step_training
>>> # Example callback running after each training step
>>> @on_step_training
... def print_callback(state):
...     print('Step {}.'.format(state[torchbearer.BATCH]))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Step 0.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_step_training() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_end_training
(func)[source]¶ The on_end_training() decorator is used to initialise a Callback with on_end_training() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_end_training
>>> # Example callback running after each training pass
>>> @on_end_training
... def print_callback(state):
...     print('Finished training pass.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Finished training pass.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_end_training() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_start_validation
(func)[source]¶ The on_start_validation() decorator is used to initialise a Callback with on_start_validation() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_start_validation
>>> # Example callback running when each validation pass starts.
>>> @on_start_validation
... def print_callback(state):
...     print('Starting validation.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).for_val_steps(1).run()
Starting validation.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_start_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_sample_validation
(func)[source]¶ The on_sample_validation() decorator is used to initialise a Callback with on_sample_validation() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_sample_validation
>>> # Example callback running after each validation sample is drawn.
>>> @on_sample_validation
... def print_callback(state):
...     print('Sampled validation data {}.'.format(state[torchbearer.X]))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).for_val_steps(1).run()
Sampled validation data None.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_sample_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_forward_validation
(func)[source]¶ The on_forward_validation() decorator is used to initialise a Callback with on_forward_validation() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_forward_validation
>>> # Example callback running after each torch model forward pass in validation.
>>> @on_forward_validation
... def print_callback(state):
...     print('Evaluated validation batch.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).for_val_steps(1).run()
Evaluated validation batch.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_forward_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_criterion_validation
(func)[source]¶ The on_criterion_validation() decorator is used to initialise a Callback with on_criterion_validation() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_criterion_validation
>>> # Example callback running after each criterion evaluation in validation.
>>> @on_criterion_validation
... def print_callback(state):
...     print('Current val loss {}.'.format(state[torchbearer.LOSS].item()))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).for_val_steps(1).run()
Current val loss 0.0.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_criterion_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_step_validation
(func)[source]¶ The on_step_validation() decorator is used to initialise a Callback with on_step_validation() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_step_validation
>>> # Example callback running at the end of each validation step.
>>> @on_step_validation
... def print_callback(state):
...     print('Validation step {}.'.format(state[torchbearer.BATCH]))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).for_val_steps(1).run()
Validation step 0.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_step_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_end_validation
(func)[source]¶ The on_end_validation() decorator is used to initialise a Callback with on_end_validation() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_end_validation
>>> # Example callback running at the end of each validation pass.
>>> @on_end_validation
... def print_callback(state):
...     print('Finished validating.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).for_val_steps(1).run()
Finished validating.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_end_validation() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_end_epoch
(func)[source]¶ The on_end_epoch() decorator is used to initialise a Callback with on_end_epoch() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_end_epoch
>>> # Example callback running each epoch
>>> @on_end_epoch
... def print_callback(state):
...     print('Finished epoch {}.'.format(state[torchbearer.EPOCH]))
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Finished epoch 0.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_end_epoch() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_checkpoint
(func)[source]¶ The on_checkpoint() decorator is used to initialise a Callback with on_checkpoint() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_checkpoint
>>> # Example callback running at checkpoint time.
>>> @on_checkpoint
... def print_callback(state):
...     print('Checkpointing.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Checkpointing.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_checkpoint() calling func
Return type: Callback
-
torchbearer.callbacks.decorators.
on_end
(func)[source]¶ The on_end() decorator is used to initialise a Callback with on_end() calling the decorated function.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import on_end
>>> # Example callback running after all training is finished.
>>> @on_end
... def print_callback(state):
...     print('Finished training model.')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
Finished training model.
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback with on_end() calling func
Return type: Callback
Utility¶
Alongside the base callback decorators that simply bind a function to a callback point, Torchbearer has a number of utility decorators that help simplify callback construction.
-
torchbearer.callbacks.decorators.
add_to_loss
(func)[source] The add_to_loss() decorator is used to initialise a Callback with the value returned from func being added to the loss.
Example:
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import add_to_loss
>>> # Example callback to add a quantity to the loss each step.
>>> @add_to_loss
... def loss_callback(state):
...     return torch.Tensor([1.125])
>>> trial = Trial(None, callbacks=[loss_callback], metrics=['loss']).for_steps(1).run()
>>> print(trial[0][1]['loss'])
1.125
Parameters: func (function) – The function(state) to decorate
Returns: Initialised callback which adds the returned value from func to the loss
Return type: Callback
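Conceptually, a decorator of this kind wraps func in a hook that updates the loss held in state. The sketch below uses a plain dict with a hypothetical 'loss' key and assumes the hook runs once the criterion has been evaluated; add_to_loss_sketch is an illustrative name and not the library function.

```python
def add_to_loss_sketch(func):
    """Return a hook that adds func(state) to the loss held in state."""
    def on_criterion(state):
        state['loss'] = state['loss'] + func(state)
    return on_criterion


@add_to_loss_sketch
def penalty(state):
    # A constant term, standing in for a regulariser computed from state
    return 1.125


state = {'loss': 0.0}
penalty(state)
```

After the call, the loss stored in state includes the extra term, so anything that later reads the loss sees the adjusted value.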
-
torchbearer.callbacks.decorators.
once
(fcn)[source] Decorator to fire a callback once in the lifetime of the callback. If the callback is a class method, each instance of the class will fire only once. For functions, only the first instance will fire (even if more than one function is present in the callback list).
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import once, on_step_training
>>> # Example callback to be called exactly once on the very first training step
>>> @once
... @on_step_training
... def print_callback(state):
...     print('This happens once ever')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
This happens once ever
Parameters: fcn (function) – the torchbearer callback function to decorate.
Returns: the decorator
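The fire-once behaviour for a plain function can be sketched with a closure flag. The name once_sketch and the 'batch' key are hypothetical, and this version does not attempt the per-instance handling of class methods described above.

```python
import functools


def once_sketch(fcn):
    """Wrap fcn so that only its first call does any work."""
    fired = False

    @functools.wraps(fcn)
    def wrapper(state):
        nonlocal fired
        if fired:
            return None  # every later call is silently skipped
        fired = True
        return fcn(state)

    return wrapper


log = []


@once_sketch
def record(state):
    log.append(state['batch'])


for batch in range(3):
    record({'batch': batch})
```

Only the first batch is recorded, however many times the wrapped function is invoked afterwards.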
-
torchbearer.callbacks.decorators.
once_per_epoch
(fcn)[source] Decorator to fire a callback once (on the first call) in any given epoch. If the callback is a class method, each instance of the class will fire once per epoch. For functions, only the first instance will fire (even if more than one function is present in the callback list).
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import once_per_epoch, on_step_training
>>> # Example callback to be called exactly once per epoch, on the first training step
>>> @once_per_epoch
... @on_step_training
... def print_callback(state):
...     print('This happens once per epoch')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(1).run()
This happens once per epoch
Note
The decorated callback may exhibit unusual behaviour if it is reused.
Parameters: fcn (function) – the torchbearer callback function to decorate.
Returns: the decorator
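A per-epoch variant can be sketched by remembering the last epoch in which the function fired. The names once_per_epoch_sketch, 'epoch' and 'batch' are hypothetical, and the real decorator may track this differently.

```python
def once_per_epoch_sketch(fcn):
    """Wrap fcn so that it runs on the first call of each epoch only."""
    last_epoch = None

    def wrapper(state):
        nonlocal last_epoch
        if state['epoch'] == last_epoch:
            return None  # already fired during this epoch
        last_epoch = state['epoch']
        return fcn(state)

    return wrapper


log = []


@once_per_epoch_sketch
def record(state):
    log.append((state['epoch'], state['batch']))


for epoch in range(2):
    for batch in range(3):
        record({'epoch': epoch, 'batch': batch})
```

In this sketch the remembered epoch outlives a single run, so reusing the wrapped function in a second run that starts at the same epoch number would skip the first call; that is one way the reuse caveat in the note can arise.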
-
torchbearer.callbacks.decorators.
only_if
(condition_expr)[source] Decorator to fire a callback only if the given conditional expression function returns True. The conditional expression can be a function of state or of self and state. If the decorated function is not a class method (i.e. it does not take self), the decorated function itself will be passed in place of self. This enables the storing of temporary variables.
Example:
>>> import torchbearer
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import only_if, on_step_training
>>> # Example callback to be called only when the given condition is true on each training step
>>> @only_if(lambda state: state[torchbearer.BATCH] == 100)
... @on_step_training
... def print_callback(state):
...     print('This is the 100th batch')
>>> trial = Trial(None, callbacks=[print_callback]).for_steps(101).run()
This is the 100th batch
Parameters: condition_expr (function(self, state) or function(state)) – a function/lambda which takes state and optionally self that must evaluate to true for the decorated torchbearer callback to be called. The state object passed to the callback will be passed as an argument to the condition function.
Returns: the decorator
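The state-only form of the condition can be sketched as a decorator factory. The names only_if_sketch and 'batch' are hypothetical, and the variant that also receives self is left out of this illustration.

```python
def only_if_sketch(condition_expr):
    """Build a decorator that calls fcn only when condition_expr(state) is true."""
    def decorator(fcn):
        def wrapper(state):
            if condition_expr(state):
                return fcn(state)
            return None  # condition not met, so the callback is skipped
        return wrapper
    return decorator


log = []


@only_if_sketch(lambda state: state['batch'] % 2 == 0)
def record(state):
    log.append(state['batch'])


for batch in range(5):
    record({'batch': batch})
```

Unlike the fire-once decorators, the condition is re-evaluated on every call, so the callback can fire any number of times.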