torchbearer

Trial

class torchbearer.trial.CallbackListInjection(callback, callback_list)[source]

This class allows a callback to be injected into a callback list without masking the methods available for mutating the list. In this way, callbacks (such as printers) can be injected seamlessly into the methods of the trial class.

Parameters:
  • callback – The callback to inject
  • callback_list (CallbackList) – The underlying callback list
append(callback_list)[source]
copy()[source]
load_state_dict(state_dict)[source]

Resume this callback list from the given state. Callbacks must be given in the same order for this to work.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:CallbackList
state_dict()[source]

Get a dict containing all of the callback states.

Returns:A dict containing parameters and persistent buffers.
Return type:dict
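
The injection idea described for CallbackListInjection can be pictured with a small pure-Python sketch. This is not torchbearer's implementation: SimpleCallbackList, InjectedList and on_start are hypothetical names, and running the injected callback first is an arbitrary choice made for the illustration.

```python
class SimpleCallbackList:
    """Hypothetical stand-in for a callback list: holds callbacks and fires them."""

    def __init__(self, callbacks):
        self.callback_list = list(callbacks)

    def append(self, callback):
        self.callback_list.append(callback)

    def on_start(self, log):
        for callback in self.callback_list:
            callback(log)


class InjectedList:
    """Hypothetical wrapper: fires one extra callback without storing it in the list."""

    def __init__(self, callback, callback_list):
        self.callback = callback
        self.callback_list = callback_list

    def append(self, callback):
        # Mutations pass straight through to the underlying list.
        self.callback_list.append(callback)

    def on_start(self, log):
        # Assumption for this sketch: the injected callback runs first.
        self.callback(log)
        self.callback_list.on_start(log)


log = []
base = SimpleCallbackList([lambda entries: entries.append("user")])
injected = InjectedList(lambda entries: entries.append("printer"), base)
injected.append(lambda entries: entries.append("late"))
injected.on_start(log)
print(log)  # ['printer', 'user', 'late']
```

The underlying list still holds only the two user callbacks, so dropping the wrapper removes the injected callback without touching the list.
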
class torchbearer.trial.Sampler(batch_loader)[source]

Sampler wraps a batch loader function and executes it when Sampler.sample() is called

Parameters:batch_loader (function) – The batch loader to execute
sample(state)[source]
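
A minimal sketch of the wrapping behaviour described for Sampler is shown below. SketchSampler and load_pair are hypothetical names, and the string keys stand in for the StateKey objects that torchbearer uses; the real class may do more than this.

```python
class SketchSampler:
    """Hypothetical stand-in for Sampler: defers to a batch loader function."""

    def __init__(self, batch_loader):
        self.batch_loader = batch_loader

    def sample(self, state):
        # Delegate to whichever loader was supplied at construction time.
        self.batch_loader(state)


def load_pair(state):
    # Hypothetical loader: pull the next (input, target) pair from an iterator in state.
    state["x"], state["y"] = next(state["iterator"])


state = {"iterator": iter([(1, "a"), (2, "b")])}
sampler = SketchSampler(load_pair)
sampler.sample(state)
first = (state["x"], state["y"])
sampler.sample(state)
second = (state["x"], state["y"])
print(first, second)  # (1, 'a') (2, 'b')
```
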
class torchbearer.trial.Trial(model, optimizer=None, criterion=None, metrics=[], callbacks=[], pass_state=False, verbose=2)[source]

The trial class contains all of the required hyper-parameters for model running in torchbearer and presents an API for model fitting, evaluating and predicting.

Parameters:
  • model (torch.nn.Module) – The base pytorch model
  • optimizer (torch.optim.Optimizer) – The optimizer used for pytorch model weight updates
  • criterion (function or None) – The final loss criterion that provides a loss value to the optimizer
  • metrics (list) – The list of torchbearer.Metric instances to process during fitting
  • callbacks (list) – The list of torchbearer.Callback instances to call during fitting
  • pass_state (bool) – If True, the torchbearer state will be passed to the model during fitting
  • verbose (int) – Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
cpu()[source]

Moves all model parameters and buffers to the CPU.

Returns:self
Return type:Trial
cuda(device=None)[source]

Moves all model parameters and buffers to the GPU.

Parameters:device (int, optional) – if specified, all parameters will be copied to that device
Returns:self
Return type:Trial
eval()[source]

Set model and metrics to evaluation mode

Returns:self
Return type:Trial
evaluate(verbose=-1, data_key=None)[source]

Evaluate this trial on the validation data.

Parameters:
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
  • data_key (StateKey) – Optional key for the data to evaluate on. Default: torchbearer.VALIDATION_DATA
Returns:

The final metric values

Return type:

dict

for_steps(train_steps=None, val_steps=None, test_steps=None)[source]

Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained for convenience.

Parameters:
  • train_steps (int, optional) – The number of training steps per epoch to run
  • val_steps (int, optional) – The number of validation steps per epoch to run
  • test_steps (int, optional) – The number of test steps per epoch to run (when using predict())
Returns:

self

Return type:

Trial

for_test_steps(steps)[source]

Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience.

Parameters:steps (int) – The number of test steps per epoch to run (when using predict())
Returns:self
Return type:Trial
for_train_steps(steps)[source]

Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience.

Parameters:steps (int) – The number of training steps per epoch to run
Returns:self
Return type:Trial
for_val_steps(steps)[source]

Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience.

Parameters:steps (int) – The number of validation steps per epoch to run
Returns:self
Return type:Trial
load_state_dict(state_dict, resume=True, **kwargs)[source]

Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, just load the model state when resume=False.

Parameters:
  • state_dict (dict) – The state dict to reload
  • resume – If True, resume from the given state. Else, just load in the model weights.
  • kwargs – See: torch.nn.Module.load_state_dict
Returns:

self

Return type:

Trial

predict(verbose=-1, data_key=None)[source]

Determine predictions for this trial on the test data.

Parameters:
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
  • data_key (StateKey) – Optional key for the data to predict on. Default: torchbearer.TEST_DATA
Returns:

Model outputs as a list

Return type:

list

replay(callbacks=[], verbose=2, one_batch=False)[source]

Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

Parameters:
  • callbacks (list) – List of callbacks to be run during the replay
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
  • one_batch (bool) – If True, only one batch per epoch is replayed. If False, all batches are replayed
Returns:

self

Return type:

Trial

run(epochs=1, verbose=-1)[source]

Run this trial for the given number of epochs, starting from the last trained epoch.

Parameters:
  • epochs (int, optional) – The number of epochs to run for
  • verbose (int, optional) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
State Requirements:
Returns:The model history (list of tuple of steps summary and epoch metric dicts)
Return type:list
state_dict(**kwargs)[source]

Get a dict containing the model and optimizer states, as well as the model history.

Parameters:kwargs – See: torch.nn.Module.state_dict
Returns:A dict containing parameters and persistent buffers.
Return type:dict
to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

Parameters:
Returns:

self

Return type:

Trial

train()[source]

Set model and metrics to training mode.

Returns:self
Return type:Trial
with_generators(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)[source]

Use this trial with the given generators. Returns self so that methods can be chained for convenience.

Parameters:
  • train_generator (DataLoader) – The training data generator to use during calls to run()
  • val_generator (DataLoader) – The validation data generator to use during calls to run() and evaluate()
  • test_generator (DataLoader) – The testing data generator to use during calls to predict()
  • train_steps (int) – The number of steps per epoch to take when using the training generator
  • val_steps (int) – The number of steps per epoch to take when using the validation generator
  • test_steps (int) – The number of steps per epoch to take when using the testing generator
Returns:

self

Return type:

Trial

with_test_data(x, batch_size=1, num_workers=1, steps=None)[source]

Use this trial with the given test data. Returns self so that methods can be chained for convenience.

Parameters:
  • x (torch.Tensor) – The test x data to use during calls to predict()
  • batch_size (int) – The size of each batch to sample from the data
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns:

self

Return type:

Trial

with_test_generator(generator, steps=None)[source]

Use this trial with the given test generator. Returns self so that methods can be chained for convenience.

Parameters:
  • generator (DataLoader) – The test data generator to use during calls to predict()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns:

self

Return type:

Trial

with_train_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]

Use this trial with the given train data. Returns self so that methods can be chained for convenience.

Parameters:
  • x (torch.Tensor) – The train x data to use during calls to run()
  • y (torch.Tensor) – The train labels to use during calls to run()
  • batch_size (int) – The size of each batch to sample from the data
  • shuffle (bool) – If True, then data will be shuffled each epoch
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns:

self

Return type:

Trial

with_train_generator(generator, steps=None)[source]

Use this trial with the given train generator. Returns self so that methods can be chained for convenience.

Parameters:
  • generator (DataLoader) – The train data generator to use during calls to run()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns:

self

Return type:

Trial

with_val_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]

Use this trial with the given validation data. Returns self so that methods can be chained for convenience.

Parameters:
  • x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
  • y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
  • batch_size (int) – The size of each batch to sample from the data
  • shuffle (bool) – If True, then data will be shuffled each epoch
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns:

self

Return type:

Trial

with_val_generator(generator, steps=None)[source]

Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.

Parameters:
  • generator (DataLoader) – The validation data generator to use during calls to run() and evaluate()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns:

self

Return type:

Trial

torchbearer.trial.deep_to(batch, device, dtype)[source]

Static method to call to() on tensors or tuples. All items in a tuple will have deep_to() called.

Parameters:
  • batch (tuple, list, torch.Tensor) – The mini-batch which requires a to() call
  • device (torch.device) – The desired device of the batch
  • dtype (torch.dtype) – The desired datatype of the batch
Returns:

The moved or cast batch

Return type:

tuple, list, torch.Tensor
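
The recursive behaviour described for deep_to can be sketched without PyTorch. FakeTensor and deep_to_sketch are hypothetical names used only for this illustration; the sketch applies the same to() call to every leaf, which may differ from how the library treats particular tensor types.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; records the device and dtype it was moved to."""

    def __init__(self, device="cpu", dtype="float32"):
        self.device = device
        self.dtype = dtype

    def to(self, device, dtype):
        return FakeTensor(device, dtype)


def deep_to_sketch(batch, device, dtype):
    """Recursively call to() on every leaf of a nested tuple or list."""
    if isinstance(batch, (list, tuple)):
        # Rebuild the container with the same type after converting each item.
        return type(batch)(deep_to_sketch(item, device, dtype) for item in batch)
    return batch.to(device, dtype)


batch = (FakeTensor(), [FakeTensor(), FakeTensor()])
moved = deep_to_sketch(batch, "cuda:0", "float16")
print(type(moved).__name__, moved[1][0].device, moved[0].dtype)  # tuple cuda:0 float16
```
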

torchbearer.trial.fluent(func)[source]

Decorator for class methods which forces return of self.
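
A decorator of this kind could look roughly like the following. fluent_sketch and Counter are hypothetical names, and this is an illustration of the pattern rather than the library's source.

```python
import functools


def fluent_sketch(func):
    """Illustrative decorator: call func, discard its result, return self."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        func(self, *args, **kwargs)
        return self

    return wrapper


class Counter:
    def __init__(self):
        self.total = 0

    @fluent_sketch
    def add(self, amount):
        self.total += amount


counter = Counter().add(2).add(3)
print(counter.total)  # 5
```

Returning self from every call is what lets the Trial methods documented above be chained.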

torchbearer.trial.get_printer(verbose, validation_label_letter)[source]
torchbearer.trial.inject_callback(callback)[source]

Decorator to inject a callback into the callback list and remove the callback after the decorated function has executed

Parameters:callback (Callback) – Callback to be injected
Returns:the decorator
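
The add-then-remove behaviour described for inject_callback can be sketched as a decorator factory over a plain list. inject_sketch, Runner and announce are hypothetical names; how torchbearer stores and restores its callback list is not shown here.

```python
import functools


def inject_sketch(callback):
    """Illustrative decorator factory: add callback before the call, remove it after."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            self.callbacks.insert(0, callback)
            try:
                return func(self, *args, **kwargs)
            finally:
                # Remove the injected callback even if func raises.
                self.callbacks.remove(callback)

        return wrapper

    return decorator


def announce():
    return "announce"


class Runner:
    def __init__(self):
        self.callbacks = []

    @inject_sketch(announce)
    def run(self):
        # Report which callbacks are visible while the method executes.
        return [callback() for callback in self.callbacks]


runner = Runner()
seen = runner.run()
print(seen, runner.callbacks)  # ['announce'] []
```
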
torchbearer.trial.inject_printer(validation_label_letter='v')[source]

The inject printer decorator is used to inject the appropriate printer callback, according to the verbosity level.

Parameters:validation_label_letter – The validation label letter to use
Returns:A decorator
torchbearer.trial.inject_sampler(data_key, predict=False)[source]

Decorator to inject a Sampler into state[torchbearer.SAMPLER], along with the specified generator into state[torchbearer.GENERATOR] and the number of steps into state[torchbearer.STEPS].

Parameters:
  • data_key (StateKey) – Key for the data to inject
  • predict (bool) – If True, the prediction batch loader is used; if False, the standard data loader is used
Returns:

the decorator

torchbearer.trial.load_batch_none(state)[source]

Load a (None, None) tuple mini-batch into state

Parameters:state (dict[str,any]) – The current state dict of the Trial.
torchbearer.trial.load_batch_predict(state)[source]

Load a prediction (input data, target) or (input data) mini-batch from iterator into state

Parameters:state (dict[str,any]) – The current state dict of the Trial.
torchbearer.trial.load_batch_standard(state)[source]

Load a standard (input data, target) tuple mini-batch from iterator into state

Parameters:state (dict[str,any]) – The current state dict of the Trial.
torchbearer.trial.update_device_and_dtype(state, *args, **kwargs)[source]

Get the data type and device values from the args / kwargs and update state.

Parameters:
  • state (State) – The dict to update
  • args – Arguments to the Trial.to() function
  • kwargs – Keyword arguments to the Trial.to() function
Returns:

device, dtype pair

Return type:

tuple

Model (Deprecated)

class torchbearer.torchbearer.Model(model, optimizer, criterion=None, metrics=[])[source]

Deprecated since version 0.2.0: Use Trial instead.

Create a torchbearer model, which wraps a base torch model and provides a training environment surrounding it.

Parameters:
  • model (torch.nn.Module) – The base pytorch model
  • optimizer (torch.optim.Optimizer) – The optimizer used for pytorch model weight updates
  • criterion (function or None) – The final loss criterion that provides a loss value to the optimizer
  • metrics (list) – Additional metrics for display and use within callbacks
cpu()[source]

Moves all model parameters and buffers to the CPU.

Returns:Self torchbearermodel
Return type:Model
cuda(device=None)[source]

Moves all model parameters and buffers to the GPU.

Parameters:device (int, optional) – if specified, all parameters will be copied to that device
Returns:Self torchbearermodel
Return type:Model
eval()[source]

Set model and metrics to evaluation mode

evaluate(x=None, y=None, batch_size=32, verbose=2, steps=None, pass_state=False)[source]

Perform an evaluation loop on given data and label tensors to evaluate metrics

Parameters:
  • x (torch.Tensor) – The input data tensor
  • y (torch.Tensor) – The target labels for data tensor x
  • batch_size (int) – The mini-batch size (number of samples processed for a single weight update)
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
  • steps (int) – The number of evaluation mini-batches to run
  • pass_state (bool) – If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
Returns:

The dictionary containing final metrics

Return type:

dict[str,any]

evaluate_generator(generator, verbose=2, steps=None, pass_state=False)[source]

Perform an evaluation loop on given data generator to evaluate metrics

Parameters:
  • generator (DataLoader) – The evaluation data generator (usually a pytorch DataLoader)
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
  • steps (int) – The number of evaluation mini-batches to run
  • pass_state (bool) – If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
Returns:

The dictionary containing final metrics

Return type:

dict[str,any]

fit(x, y, batch_size=None, epochs=1, verbose=2, callbacks=[], validation_split=None, validation_data=None, shuffle=True, initial_epoch=0, steps_per_epoch=None, validation_steps=None, workers=1, pass_state=False)[source]

Perform fitting of a model to given data and label tensors

Parameters:
  • x (torch.Tensor) – The input data tensor
  • y (torch.Tensor) – The target labels for data tensor x
  • batch_size (int) – The mini-batch size (number of samples processed for a single weight update)
  • epochs (int) – The number of training epochs to be run (each sample from the dataset is viewed exactly once)
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
  • callbacks (list) – The list of torchbearer callbacks to be called during training and validation
  • validation_split (float) – Fraction of the training dataset to be set aside for validation testing
  • validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data tensor
  • shuffle (bool) – If True, mini-batches of training/validation data are randomly selected; if False, mini-batch samples are selected in the order defined by the dataset
  • initial_epoch (int) – The integer value representing the first epoch - useful for continuing training after a number of epochs
  • steps_per_epoch (int) – The number of training mini-batches to run per epoch
  • validation_steps (int) – The number of validation mini-batches to run per epoch
  • workers (int) – The number of cpu workers devoted to batch loading and aggregating
  • pass_state (bool) – If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
Returns:

The final state context dictionary

Return type:

dict[str,any]

fit_generator(generator, train_steps=None, epochs=1, verbose=2, callbacks=[], validation_generator=None, validation_steps=None, initial_epoch=0, pass_state=False)[source]

Perform fitting of a model to given data generator

Parameters:
  • generator (DataLoader) – The training data generator (usually a pytorch DataLoader)
  • train_steps (int) – The number of training mini-batches to run per epoch
  • epochs (int) – The number of training epochs to be run (each sample from the dataset is viewed exactly once)
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
  • callbacks (list) – The list of torchbearer callbacks to be called during training and validation
  • validation_generator (DataLoader) – The validation data generator (usually a pytorch DataLoader)
  • validation_steps (int) – The number of validation mini-batches to run per epoch
  • initial_epoch (int) – The integer value representing the first epoch - useful for continuing training after a number of epochs
  • pass_state (bool) – If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
Returns:

The final state context dictionary

Return type:

dict[str,any]

load_state_dict(state_dict, **kwargs)[source]

Copies parameters and buffers from state_dict() into this module and its descendants.

Parameters:
predict(x=None, batch_size=32, verbose=2, steps=None, pass_state=False)[source]

Perform a prediction loop on given data tensor to predict labels

Parameters:
  • x (torch.Tensor) – The input data tensor
  • batch_size (int) – The mini-batch size (number of samples processed for a single weight update)
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
  • steps (int) – The number of evaluation mini-batches to run
  • pass_state (bool) – If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
Returns:

Tensor of final predicted labels

Return type:

torch.Tensor

predict_generator(generator, verbose=2, steps=None, pass_state=False)[source]

Perform a prediction loop on given data generator to predict labels

Parameters:
  • generator (DataLoader) – The prediction data generator (usually a pytorch DataLoader)
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
  • steps (int) – The number of evaluation mini-batches to run
  • pass_state (bool) – If True the state dictionary is passed to the torch model forward method, if False only the input data is passed
Returns:

Tensor of final predicted labels

Return type:

torch.Tensor

state_dict(**kwargs)[source]
Parameters:kwargs – See: torch.nn.Module.state_dict

Returns:A dict containing parameters and persistent buffers.
Return type:dict
to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

Parameters:
Returns:

Self torchbearermodel

Return type:

Model

train()[source]

Set model and metrics to training mode

State

The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced during model fitting. This module defines classes for interacting with state and all of the built in state keys used throughout torchbearer. The state_key() function can be used to create custom state keys for use in callbacks or metrics.

Example:

from torchbearer import state_key
MY_KEY = state_key('my_test_key')
torchbearer.state.BACKWARD_ARGS = backward_args

The optional arguments which should be passed to the backward call

torchbearer.state.BATCH = t

The current batch number

torchbearer.state.CALLBACK_LIST = callback_list

The CallbackList object which is called by the Trial

torchbearer.state.CRITERION = criterion

The criterion to use when model fitting

torchbearer.state.DATA = data

The string name of the current data

torchbearer.state.DATA_TYPE = dtype

The data type of tensors in use by the model, match this to avoid type issues

torchbearer.state.DEVICE = device

The device currently in use by the Trial and PyTorch model

torchbearer.state.EPOCH = epoch

The current epoch number

torchbearer.state.FINAL_PREDICTIONS = final_predictions

The key which maps to the predictions over the dataset when calling predict

torchbearer.state.GENERATOR = generator

The current data generator (DataLoader)

torchbearer.state.HISTORY = history

The history list of the Trial instance

torchbearer.state.INPUT = x

The current batch of inputs

torchbearer.state.ITERATOR = iterator

The current iterator

torchbearer.state.LOSS = loss

The current value for the loss

torchbearer.state.MAX_EPOCHS = max_epochs

The total number of epochs to run for

torchbearer.state.METRICS = metrics

The metric dict from the current batch of data

torchbearer.state.METRIC_LIST = metric_list

The list of metrics in use by the Trial

torchbearer.state.MODEL = model

The PyTorch module / model that will be trained

torchbearer.state.OPTIMIZER = optimizer

The optimizer to use when model fitting

torchbearer.state.PREDICTION = y_pred

The current batch of predictions

torchbearer.state.SAMPLER = sampler

The sampler which loads data from the generator onto the correct device

torchbearer.state.SELF = self

A self reference to the Trial object for persistence etc.

torchbearer.state.STEPS = steps

The current number of steps per epoch

torchbearer.state.STOP_TRAINING = stop_training

A flag that can be set to true to stop the current fit call

class torchbearer.state.State[source]

State dictionary that behaves like a python dict but accepts StateKeys

get_key(statekey)[source]
update([E, ]**F) → None. Update D from dict/iterable E and F.[source]

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]
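
The update signature above is the one documented for Python's built-in dict. The three cases it describes can be seen with a plain dict (used here instead of State so that the example has no dependencies):

```python
d = {"a": 1}
d.update({"b": 2})             # E has .keys(): each key is copied from the mapping
d.update([("c", 3)])           # E lacks .keys(): treated as (key, value) pairs
d.update({"a": 0}, a=10, e=5)  # keyword arguments F are applied last
print(d)  # {'a': 10, 'b': 2, 'c': 3, 'e': 5}
```
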

class torchbearer.state.StateKey(key)[source]

StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.

Parameters:key (String) – Base key
process(state)[source]

process_final(state)[source]

torchbearer.state.TARGET = y_true

The current batch of ground truth data

torchbearer.state.TEST_DATA = test_data

The flag representing test data

torchbearer.state.TEST_GENERATOR = test_generator

The test data generator in the Trial object

torchbearer.state.TEST_STEPS = test_steps

The number of test steps to take

torchbearer.state.TIMINGS = timings

The timings keys used by the timer callback

torchbearer.state.TRAIN_DATA = train_data

The flag representing train data

torchbearer.state.TRAIN_GENERATOR = train_generator

The train data generator in the Trial object

torchbearer.state.TRAIN_STEPS = train_steps

The number of train steps to take

torchbearer.state.VALIDATION_DATA = validation_data

The flag representing validation data

torchbearer.state.VALIDATION_GENERATOR = validation_generator

The validation data generator in the Trial object

torchbearer.state.VALIDATION_STEPS = validation_steps

The number of validation steps to take

torchbearer.state.VERSION = torchbearer_version

The torchbearer version

torchbearer.state.X = x

The current batch of inputs

torchbearer.state.Y_PRED = y_pred

The current batch of predictions

torchbearer.state.Y_TRUE = y_true

The current batch of ground truth data

torchbearer.state.state_key(key)[source]

Computes and returns a non-conflicting key for the state dictionary when given a seed key

Parameters:key (String) – The seed key - basis for new state key
Returns:New state key
Return type:StateKey
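
One way to produce a non-conflicting key from a seed is to keep a registry of issued keys and append a numeric suffix on collision. The sketch below assumes that scheme; state_key_sketch and _used_keys are hypothetical names, and the suffix format is a guess rather than documented behaviour.

```python
_used_keys = set()  # hypothetical registry of keys handed out so far


def state_key_sketch(key):
    """Return key unchanged if unused, else append the first free numeric suffix."""
    candidate = key
    count = 0
    while candidate in _used_keys:
        count += 1
        candidate = "{}_{}".format(key, count)
    _used_keys.add(candidate)
    return candidate


first = state_key_sketch("my_test_key")
second = state_key_sketch("my_test_key")
third = state_key_sketch("my_test_key")
print(first, second, third)  # my_test_key my_test_key_1 my_test_key_2
```
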

Utilities

class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)[source]
get_train_dataset(dataset)[source]

Creates a training dataset from existing dataset

Parameters:dataset (torch.utils.data.Dataset) – Dataset to be split into a training dataset
Returns:Training dataset split from whole dataset
Return type:torch.utils.data.Dataset
get_val_dataset(dataset)[source]

Creates a validation dataset from existing dataset

Parameters:dataset (torch.utils.data.Dataset) – Dataset to be split into a validation dataset
Returns:Validation dataset split from whole dataset
Return type:torch.utils.data.Dataset
torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)[source]

Generate validation and training datasets from whole dataset tensors

Parameters:
  • x (torch.Tensor) – Data tensor for dataset
  • y (torch.Tensor) – Label tensor for dataset
  • validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
  • validation_split (float) – Fraction of dataset to be used for validation
  • shuffle (bool) – If True randomize tensor order before splitting else do not randomize
Returns:

Training and validation datasets

Return type:

tuple

torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)[source]

Generate training and validation tensors from whole dataset data and label tensors

Parameters:
  • x (torch.Tensor) – Data tensor for whole dataset
  • y (torch.Tensor) – Label tensor for whole dataset
  • split (float) – Fraction of dataset to be used for validation
  • shuffle (bool) – If True randomize tensor order before splitting else do not randomize
Returns:

Training and validation tensors (training data, training labels, validation data, validation labels)

Return type:

tuple
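
The splitting described for train_valid_splitter can be illustrated on plain Python lists. split_sketch and its seed parameter are hypothetical, and the choices made here (taking validation samples from the front of the index order, truncating the count with int()) are assumptions of the sketch, not documented behaviour.

```python
import random


def split_sketch(x, y, split, shuffle=True, seed=None):
    """Split paired sequences into (x_train, y_train, x_val, y_val)."""
    indices = list(range(len(x)))
    if shuffle:
        # Shuffle indices rather than data so x and y stay aligned.
        random.Random(seed).shuffle(indices)
    n_val = int(len(x) * split)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return (
        [x[i] for i in train_idx],
        [y[i] for i in train_idx],
        [x[i] for i in val_idx],
        [y[i] for i in val_idx],
    )


x = list(range(10))
y = [value * value for value in x]
x_train, y_train, x_val, y_val = split_sketch(x, y, split=0.2, shuffle=False)
print(x_train, x_val)  # [2, 3, 4, 5, 6, 7, 8, 9] [0, 1]
```

With shuffle=True the same index permutation is applied to both x and y, so each label stays paired with its sample.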