torchbearer

Trial

class torchbearer.trial.CallbackListInjection(callback, callback_list)[source]

This class allows a callback to be injected into a callback list without masking the methods available for mutating the list. In this way, callbacks (such as printers) can be injected seamlessly into the methods of the trial class.

Parameters:
  • callback (Callback) – The callback to inject
  • callback_list (CallbackList) – The list to inject the callback into
append(callback_list)[source]
copy()[source]
load_state_dict(state_dict)[source]

Resume this callback list from the given state. Callbacks must be given in the same order for this to work.

Parameters:state_dict (dict) – The state dict to reload
Returns:self
Return type:CallbackList
state_dict()[source]

Get a dict containing all of the callback states.

Returns:A dict containing parameters and persistent buffers.
Return type:dict
class torchbearer.trial.MockOptimizer[source]

The Mock Optimizer will be used in place of an optimizer in the event that none is passed to the Trial class.

add_param_group(param_group)[source]
load_state_dict(state_dict)[source]
state_dict()[source]
step(closure=None)[source]
zero_grad()[source]
class torchbearer.trial.Sampler(batch_loader)[source]

Sampler wraps a batch loader function and executes it when Sampler.sample() is called.

Parameters:batch_loader (func) – The batch loader to execute
sample(state)[source]
class torchbearer.trial.Trial(model, optimizer=None, criterion=None, metrics=[], callbacks=[], verbose=2)[source]

The trial class contains all of the hyper-parameters required for running a model in torchbearer and presents an API for model fitting, evaluating and predicting.

@article{2018torchbearer,
  title={Torchbearer: A Model Fitting Library for PyTorch},
  author={Harris, Ethan and Painter, Matthew and Hare, Jonathon},
  journal={arXiv preprint arXiv:1809.03363},
  year={2018}
}
Parameters:
  • model (torch.nn.Module) – The base PyTorch model
  • optimizer (torch.optim.Optimizer) – The optimizer used for PyTorch model weight updates
  • criterion (func / None) – The final loss criterion that provides a loss value to the optimizer
  • metrics (list) – The list of torchbearer.Metric instances to process during fitting
  • callbacks (list) – The list of torchbearer.Callback instances to call during fitting
  • verbose (int) – Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
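
Example (a minimal sketch of constructing and running a Trial; the model, data and hyper-parameter values are illustrative placeholders, not part of this API):

import torch
import torch.nn as nn
from torchbearer import Trial

model = nn.Linear(10, 2)                 # placeholder model
x = torch.randn(128, 10)                 # placeholder inputs
y = torch.randint(0, 2, (128,))          # placeholder targets

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

trial = Trial(model, optimizer, criterion, metrics=['loss'], verbose=1)
trial.with_train_data(x, y, batch_size=32)
history = trial.run(epochs=3)            # list of per-epoch metric summaries
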
cpu()[source]

Moves all model parameters and buffers to the CPU.

Returns:self
Return type:Trial
cuda(device=None)[source]

Moves all model parameters and buffers to the GPU.

Parameters:device (int) – If specified, all parameters will be copied to that device
Returns:self
Return type:Trial
eval()[source]

Set model and metrics to evaluation mode

Returns:self
Return type:Trial
evaluate(verbose=-1, data_key=None)[source]

Evaluate this trial on the validation data.

Parameters:
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
  • data_key (StateKey) – Optional StateKey for the data to evaluate on. Default: torchbearer.VALIDATION_DATA
Returns:The final metric values
Return type:dict

for_inf_steps(train=True, val=True, test=True)[source]

Use this trial with infinite steps. Returns self so that methods can be chained for convenience.

Parameters:
  • train (bool) – Use an infinite number of training steps
  • val (bool) – Use an infinite number of validation steps
  • test (bool) – Use an infinite number of test steps
Returns:self
Return type:Trial

for_inf_test_steps()[source]

Use this trial with an infinite number of test steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.

Returns:self
Return type:Trial
for_inf_train_steps()[source]

Use this trial with an infinite number of training steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.

Returns:self
Return type:Trial
for_inf_val_steps()[source]

Use this trial with an infinite number of validation steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.

Returns:self
Return type:Trial
for_steps(train_steps=None, val_steps=None, test_steps=None)[source]

Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader will be refreshed as though a new epoch had started. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.

Parameters:
  • train_steps (int) – The number of training steps per epoch to run
  • val_steps (int) – The number of validation steps per epoch to run
  • test_steps (int) – The number of test steps per epoch to run (when using predict())
Returns:self
Return type:Trial
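
Example (a sketch of capping epoch length; model, optimizer, criterion and the two DataLoaders are assumed to exist as in the earlier sketch):

trial = Trial(model, optimizer, criterion, metrics=['loss'])
trial.with_generators(train_generator=train_loader, val_generator=val_loader)
trial.for_steps(train_steps=100, val_steps=10)  # cap each epoch at 100 train / 10 val batches
trial.run(epochs=5)
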

for_test_steps(steps)[source]

Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set; this is useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader will be refreshed as though a new epoch had started. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.

Parameters:steps (int) – The number of test steps per epoch to run (when using predict())
Returns:self
Return type:Trial
for_train_steps(steps)[source]

Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set; this is useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader will be refreshed as though a new epoch had started. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.

Parameters:steps (int) – The number of training steps per epoch to run.
Returns:self
Return type:Trial
for_val_steps(steps)[source]

Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set; this is useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader will be refreshed as though a new epoch had started. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.

Parameters:steps (int) – The number of validation steps per epoch to run
Returns:self
Return type:Trial
load_state_dict(state_dict, resume=True, **kwargs)[source]

Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, just load the model state when resume=False.

Parameters:
  • state_dict (dict) – The state dict to reload
  • resume (bool) – If True, resume from the given state. Else, just load in the model weights.
  • kwargs – See: torch.nn.Module.load_state_dict
Returns:self
Return type:Trial
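
Example (a sketch of a save / resume round trip; the file name is illustrative and model, optimizer and criterion are assumed from context):

import torch

torch.save(trial.state_dict(), 'trial.pt')       # save model, optimizer and history

trial = Trial(model, optimizer, criterion, metrics=['loss'])  # rebuild the same way
trial.load_state_dict(torch.load('trial.pt'))    # restore the saved state
trial.run(epochs=10)                             # resumes from the last trained epoch
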

predict(verbose=-1, data_key=None)[source]

Determine predictions for this trial on the test data.

Parameters:
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
  • data_key (StateKey) – Optional StateKey for the data to predict on. Default: torchbearer.TEST_DATA
Returns:Model outputs as a list
Return type:list
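
Example (a sketch; x_val, y_val and x_test are assumed to be tensors prepared elsewhere):

import torchbearer

trial.with_val_data(x_val, y_val).with_test_data(x_test)
val_metrics = trial.evaluate()   # defaults to torchbearer.VALIDATION_DATA
predictions = trial.predict()    # list of model outputs on the test data
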

replay(callbacks=[], verbose=2, one_batch=False)[source]

Replay the fit passes stored in history with the given callbacks; useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

Parameters:
  • callbacks (list) – List of callbacks to be run during the replay
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
  • one_batch (bool) – If True, only one batch per epoch is replayed. If False, all batches are replayed
Returns:self
Return type:Trial
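
Example (a sketch of replaying a reloaded trial; my_callback stands in for any Callback instance and the file name is illustrative):

import torch

trial = Trial(model, optimizer, criterion, metrics=['loss'])
trial.load_state_dict(torch.load('trial.pt'))     # reload a previously saved trial
trial.replay(callbacks=[my_callback], verbose=1)  # re-run printers / loggers over history
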

run(epochs=1, verbose=-1)[source]

Run this trial for the given number of epochs, starting from the last trained epoch.

Parameters:
  • epochs (int, optional) – The number of epochs to run for
  • verbose (int, optional) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
Returns:The model history (list of tuple of steps summary and epoch metric dicts)
Return type:list
state_dict(**kwargs)[source]

Get a dict containing the model and optimizer states, as well as the model history.

Parameters:kwargs – See: torch.nn.Module.state_dict
Returns:A dict containing parameters and persistent buffers.
Return type:dict
to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

Parameters:
  • args – See: torch.nn.Module.to
  • kwargs – See: torch.nn.Module.to
Returns:self
Return type:Trial

train()[source]

Set model and metrics to training mode.

Returns:self
Return type:Trial
with_closure(closure)[source]

Use this trial with a custom closure.

Parameters:closure (function) – Function of state that defines the custom closure
Returns:self
Return type:Trial
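
Example (a sketch of a custom closure that approximates the usual zero_grad / forward / loss / backward sequence; the default closure may differ in detail):

import torchbearer

def closure(state):
    state[torchbearer.OPTIMIZER].zero_grad()
    y_pred = state[torchbearer.MODEL](state[torchbearer.X])
    state[torchbearer.Y_PRED] = y_pred
    loss = state[torchbearer.CRITERION](y_pred, state[torchbearer.Y_TRUE])
    state[torchbearer.LOSS] = loss
    loss.backward()

trial.with_closure(closure)
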
with_generators(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)[source]

Use this trial with the given generators. Returns self so that methods can be chained for convenience.

Parameters:
  • train_generator – The training data generator to use during calls to run()
  • val_generator – The validation data generator to use during calls to run() and evaluate()
  • test_generator – The testing data generator to use during calls to predict()
  • train_steps (int) – The number of steps per epoch to take when using the training generator
  • val_steps (int) – The number of steps per epoch to take when using the validation generator
  • test_steps (int) – The number of steps per epoch to take when using the testing generator
Returns:self
Return type:Trial
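
Example (a sketch; x_train, y_train, x_val and y_val are assumed tensors, and model, optimizer and criterion are assumed from context):

from torch.utils.data import DataLoader, TensorDataset

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=32)

trial = Trial(model, optimizer, criterion, metrics=['loss'])
trial.with_generators(train_generator=train_loader, val_generator=val_loader)
trial.run(epochs=10)
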

with_inf_train_loader()[source]

Use this trial with a training iterator that refreshes when it finishes, rather than at the end of each epoch. This allows the number of training steps to be set lower than the size of the generator while still training the model on all training samples, provided enough “epochs” are run.

Returns:self
Return type:Trial
with_test_data(x, batch_size=1, num_workers=1, steps=None)[source]

Use this trial with the given test data. Returns self so that methods can be chained for convenience.

Parameters:
  • x (torch.Tensor) – The test x data to use during calls to predict()
  • batch_size (int) – The size of each batch to sample from the data
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns:self
Return type:Trial

with_test_generator(generator, steps=None)[source]

Use this trial with the given test generator. Returns self so that methods can be chained for convenience.

Parameters:
  • generator – The test data generator to use during calls to predict()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns:self
Return type:Trial

with_train_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]

Use this trial with the given train data. Returns self so that methods can be chained for convenience.

Parameters:
  • x (torch.Tensor) – The train x data to use during calls to run()
  • y (torch.Tensor) – The train labels to use during calls to run()
  • batch_size (int) – The size of each batch to sample from the data
  • shuffle (bool) – If True, then data will be shuffled each epoch
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns:self
Return type:Trial

with_train_generator(generator, steps=None)[source]

Use this trial with the given train generator. Returns self so that methods can be chained for convenience.

Parameters:
  • generator – The train data generator to use during calls to run()
  • steps (int) – The number of steps per epoch to take when using this generator.
Returns:self
Return type:Trial

with_val_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]

Use this trial with the given validation data. Returns self so that methods can be chained for convenience.

Parameters:
  • x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
  • y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
  • batch_size (int) – The size of each batch to sample from the data
  • shuffle (bool) – If True, then data will be shuffled each epoch
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns:self
Return type:Trial

with_val_generator(generator, steps=None)[source]

Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.

Parameters:
  • generator – The validation data generator to use during calls to run() and evaluate()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns:self
Return type:Trial

torchbearer.trial.deep_to(batch, device, dtype)[source]

Static method to call to() on tensors or tuples. All items in a tuple or list will have deep_to() called on them.

Parameters:
  • batch (tuple / list / torch.Tensor) – The mini-batch which requires a to() call
  • device (torch.device) – The desired device of the batch
  • dtype (torch.dtype) – The desired datatype of the batch
Returns:The moved or casted batch
Return type:tuple / list / torch.Tensor
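
Example (a sketch; the tensors are synthetic):

import torch
from torchbearer.trial import deep_to

batch = (torch.randn(4, 3), torch.randn(4, 1))   # an (input, target) mini-batch
batch = deep_to(batch, torch.device('cpu'), torch.float32)
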

torchbearer.trial.get_default(fcn, arg)[source]
torchbearer.trial.get_printer(verbose, validation_label_letter)[source]
torchbearer.trial.inject_callback(callback)[source]

Decorator to inject a callback into the callback list and remove it after the decorated function has executed.

Parameters:callback (Callback) – Callback to be injected
Returns:The decorator
torchbearer.trial.inject_printer(validation_label_letter='v')[source]

The inject printer decorator is used to inject the appropriate printer callback, according to the verbosity level.

Parameters:validation_label_letter (str) – The validation label letter to use
Returns:A decorator
torchbearer.trial.inject_sampler(data_key, predict=False)[source]

Decorator to inject a Sampler into state[torchbearer.SAMPLER], along with the specified generator into state[torchbearer.GENERATOR] and the number of steps into state[torchbearer.STEPS].

Parameters:
  • data_key (StateKey) – StateKey for the data to inject
  • predict (bool) – If True, the prediction batch loader is used; if False, the standard data loader is used
Returns:The decorator

torchbearer.trial.load_batch_infinite(loader)[source]

Wraps a batch loader and refreshes the iterator once it has been completed.

Parameters:loader – batch loader to wrap
torchbearer.trial.load_batch_none(state)[source]

Load a (None, None) tuple mini-batch into state

Parameters:state (dict) – The current state dict of the Trial.
torchbearer.trial.load_batch_predict(state)[source]

Load a prediction (input data, target) or (input data) mini-batch from the iterator into state

Parameters:state (dict) – The current state dict of the Trial.
torchbearer.trial.load_batch_standard(state)[source]

Load a standard (input data, target) tuple mini-batch from the iterator into state

Parameters:state (dict) – The current state dict of the Trial.
torchbearer.trial.update_device_and_dtype(state, *args, **kwargs)[source]

Gets the data type and device values from the args / kwargs and updates state.

Parameters:
  • state (State) – The state dict to update
  • args – Arguments to search for the device and dtype values
  • kwargs – Keyword arguments to search for the device and dtype values
Returns:device, dtype pair

State

The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced during model fitting. This module defines classes for interacting with state and all of the built-in state keys used throughout torchbearer. The state_key() function can be used to create custom state keys for use in callbacks or metrics.

Example:

from torchbearer import state_key
MY_KEY = state_key('my_test_key')
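
A custom key can then be read and written like any other entry in state, for example from a decorator-based callback (a sketch, assuming the callback decorators in torchbearer.callbacks):

import torchbearer
from torchbearer import state_key
from torchbearer.callbacks import on_step_training

MY_KEY = state_key('my_test_key')

@on_step_training
def store_loss(state):
    state[MY_KEY] = state[torchbearer.LOSS].item()  # stash a value under the custom key
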
torchbearer.state.BACKWARD_ARGS = backward_args

The optional arguments which should be passed to the backward call

torchbearer.state.BATCH = t

The current batch number

torchbearer.state.CALLBACK_LIST = callback_list

The CallbackList object which is called by the Trial

torchbearer.state.CRITERION = criterion

The criterion to use when model fitting

torchbearer.state.DATA = data

The string name of the current data

torchbearer.state.DATA_TYPE = dtype

The data type of tensors in use by the model, match this to avoid type issues

torchbearer.state.DEVICE = device

The device currently in use by the Trial and PyTorch model

torchbearer.state.EPOCH = epoch

The current epoch number

torchbearer.state.FINAL_PREDICTIONS = final_predictions

The key which maps to the predictions over the dataset when calling predict

torchbearer.state.GENERATOR = generator

The current data generator (DataLoader)

torchbearer.state.HISTORY = history

The history list of the Trial instance

torchbearer.state.INF_TRAIN_LOADING = inf_train_loading

Flag for refreshing the training iterator when it finishes, instead of at each epoch

torchbearer.state.INPUT = x

The current batch of inputs

torchbearer.state.ITERATOR = iterator

The current iterator

torchbearer.state.LOSS = loss

The current value for the loss

torchbearer.state.MAX_EPOCHS = max_epochs

The total number of epochs to run for

torchbearer.state.METRICS = metrics

The metric dict from the current batch of data

torchbearer.state.METRIC_LIST = metric_list

The list of metrics in use by the Trial

torchbearer.state.MODEL = model

The PyTorch module / model that will be trained

torchbearer.state.OPTIMIZER = optimizer

The optimizer to use when model fitting

torchbearer.state.PREDICTION = y_pred

The current batch of predictions

torchbearer.state.SAMPLER = sampler

The sampler which loads data from the generator onto the correct device

torchbearer.state.SELF = self

A self reference to the Trial object for persistence etc.

torchbearer.state.STEPS = steps

The current number of steps per epoch

torchbearer.state.STOP_TRAINING = stop_training

A flag that can be set to true to stop the current fit call

class torchbearer.state.State[source]

State dictionary that behaves like a Python dict but accepts StateKeys

data
get_key(statekey)[source]
update([E, ]**F) → None. Update D from dict/iterable E and F.[source]

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]
If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v
In either case, this is followed by: for k in F: D[k] = F[k]

class torchbearer.state.StateKey(key)[source]

StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.

Parameters:key (str) – Base key
process(state)[source]

Process the state and update the metric for one iteration.

Parameters:state – The arguments given to the metric; if this is a root level metric, this will be the state dictionary
Returns:None, or the value of the metric for this batch
process_final(state)[source]

Process the terminal state and output the final value of the metric.

Parameters:state – The arguments given to the metric; if this is a root level metric, this will be the state dictionary
Returns:None or the value of the metric for this epoch
torchbearer.state.TARGET = y_true

The current batch of ground truth data

torchbearer.state.TEST_DATA = test_data

The flag representing test data

torchbearer.state.TEST_GENERATOR = test_generator

The test data generator in the Trial object

torchbearer.state.TEST_STEPS = test_steps

The number of test steps to take

torchbearer.state.TIMINGS = timings

The timings keys used by the timer callback

torchbearer.state.TRAIN_DATA = train_data

The flag representing train data

torchbearer.state.TRAIN_GENERATOR = train_generator

The train data generator in the Trial object

torchbearer.state.TRAIN_STEPS = train_steps

The number of train steps to take

torchbearer.state.VALIDATION_DATA = validation_data

The flag representing validation data

torchbearer.state.VALIDATION_GENERATOR = validation_generator

The validation data generator in the Trial object

torchbearer.state.VALIDATION_STEPS = validation_steps

The number of validation steps to take

torchbearer.state.VERSION = torchbearer_version

The torchbearer version

torchbearer.state.X = x

The current batch of inputs

torchbearer.state.Y_PRED = y_pred

The current batch of predictions

torchbearer.state.Y_TRUE = y_true

The current batch of ground truth data

torchbearer.state.state_key(key)[source]

Computes and returns a non-conflicting key for the state dictionary when given a seed key

Parameters:key (str) – The seed key - basis for new state key
Returns:New state key
Return type:StateKey

Utilities

class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)[source]

Generates training and validation split ids for a dataset of the given length, which can then be used to create training and validation datasets.
get_train_dataset(dataset)[source]

Creates a training dataset from an existing dataset.

Parameters:dataset (torch.utils.data.Dataset) – Dataset to be split into a training dataset
Returns:Training dataset split from whole dataset
Return type:torch.utils.data.Dataset
get_val_dataset(dataset)[source]

Creates a validation dataset from an existing dataset.

Parameters:dataset (torch.utils.data.Dataset) – Dataset to be split into a validation dataset

Returns:Validation dataset split from whole dataset
Return type:torch.utils.data.Dataset
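
Example (a sketch using a synthetic TensorDataset):

import torch
from torch.utils.data import TensorDataset
from torchbearer.cv_utils import DatasetValidationSplitter

dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

splitter = DatasetValidationSplitter(len(dataset), 0.1, shuffle_seed=42)  # hold out 10%
train_set = splitter.get_train_dataset(dataset)
val_set = splitter.get_val_dataset(dataset)
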
class torchbearer.cv_utils.SubsetDataset(dataset, ids)[source]
torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)[source]

Generate validation and training datasets from whole dataset tensors

Parameters:
  • x (torch.Tensor) – Data tensor for dataset
  • y (torch.Tensor) – Label tensor for dataset
  • validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
  • validation_split (float) – Fraction of dataset to be used for validation
  • shuffle (bool) – If True, randomize tensor order before splitting; otherwise, do not randomize
Returns:Training and validation datasets

torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)[source]

Generate training and validation tensors from whole dataset data and label tensors

Parameters:
  • x (torch.Tensor) – Data tensor for whole dataset
  • y (torch.Tensor) – Label tensor for whole dataset
  • split (float) – Fraction of dataset to be used for validation
  • shuffle (bool) – If True, randomize tensor order before splitting; otherwise, do not randomize
Returns:Training and validation tensors (training data, training labels, validation data, validation labels)
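
Example (a sketch with synthetic tensors; the return order follows the description above):

import torch
from torchbearer.cv_utils import train_valid_splitter

x, y = torch.randn(100, 10), torch.randint(0, 2, (100,))
x_train, y_train, x_val, y_val = train_valid_splitter(x, y, 0.2, shuffle=True)
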