torchbearer

Trial

class torchbearer.Trial(model, optimizer=None, criterion=None, metrics=[], callbacks=[], verbose=2)

The Trial class contains all of the hyper-parameters required to run a model in torchbearer and presents an API for model fitting, evaluation and prediction.

Example:

>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial

# Example trial that attempts to minimise the output of a linear layer.
# Makes use of a callback to input the random data at each batch and a loss that is the absolute value of the
# linear layer output. Runs for 10 iterations and a single epoch.
>>> model = torch.nn.Linear(2,1)
>>> optimiser = torch.optim.Adam(model.parameters(), lr=3e-4)

>>> @torchbearer.callbacks.on_sample
... def initial_data(state):
...     state[torchbearer.X] = torch.rand(1, 2)*10
>>> def minimise_output_loss(y_pred, y_true):
...     return torch.abs(y_pred)
>>> trial = Trial(model, optimiser, minimise_output_loss, ['loss'], [initial_data]).for_steps(10).run(1)

Citation:

@article{2018torchbearer,
  title={Torchbearer: A Model Fitting Library for PyTorch},
  author={Harris, Ethan and Painter, Matthew and Hare, Jonathon},
  journal={arXiv preprint arXiv:1809.03363},
  year={2018}
}
Parameters:
  • model (torch.nn.Module) – The base PyTorch model
  • optimizer (torch.optim.Optimizer) – The optimizer used for PyTorch model weight updates
  • criterion (func / None) – The final loss criterion that provides a loss value to the optimizer
  • metrics (list) – The list of torchbearer.Metric instances to process during fitting
  • callbacks (list) – The list of torchbearer.Callback instances to call during fitting
  • verbose (int) – Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress

for_train_steps(steps)

Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set, which is useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader keeps being refreshed until stopped by the STOP_TRAINING flag or similar.

Example:

# Simple trial that runs for 100 training iterations, in this case optimising nothing
>>> from torchbearer import Trial
>>> trial = Trial(None).for_train_steps(100)
Parameters: steps (int) – The number of training steps per epoch to run.
Returns: self
Return type: Trial
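
As a rough mental model of the step semantics described above, the following pure-Python sketch runs one epoch of a fixed number of steps, re-creating the iterator whenever the data runs out and treating -1 as "run until stopped". The function run_epoch and its stop_after parameter are hypothetical and not part of torchbearer; stop_after stands in for the STOP_TRAINING flag, which the real Trial reads from its state.

```python
from typing import Iterable, List, Optional


def run_epoch(data: Iterable, steps: int, stop_after: Optional[int] = None) -> List:
    """Hypothetical sketch: take `steps` batches, refreshing the iterator when exhausted.

    steps == -1 means "run until stopped"; `stop_after` is a stand-in
    for torchbearer's STOP_TRAINING flag (assumption for illustration).
    """
    seen = []
    iterator = iter(data)
    step = 0
    while steps == -1 or step < steps:
        if stop_after is not None and step >= stop_after:
            break  # stand-in for the STOP_TRAINING flag being set
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(data)  # refresh as if a new epoch had started
            batch = next(iterator)
        seen.append(batch)
        step += 1
    return seen
```

For example, run_epoch([1, 2, 3], 5) wraps around to give [1, 2, 3, 1, 2], while run_epoch([1, 2, 3], -1, stop_after=4) stops after four batches.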

with_train_generator(generator, steps=None)

Use this trial with the given train generator. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 100 training iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=True, download=True, transform=transforms.ToTensor()))
>>> trial = Trial(None).with_train_generator(dataloader).for_steps(100).run(1)
Parameters:
  • generator – The train data generator to use during calls to run()
  • steps (int) – The number of steps per epoch to take when using this generator.
Returns: self
Return type: Trial

with_train_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)

Use this trial with the given train data. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 10 training iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> trial = Trial(None).with_train_data(data, targets).for_steps(10).run(1)
Parameters:
  • x (torch.Tensor) – The train x data to use during calls to run()
  • y (torch.Tensor) – The train labels to use during calls to run()
  • batch_size (int) – The size of each batch to sample from the data
  • shuffle (bool) – If True, then data will be shuffled each epoch
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
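
with_train_data wraps x and y in a data loader for you. Conceptually, each epoch then yields mini-batches of paired samples, as in the sketch below, which uses plain Python lists in place of tensors. The function make_batches and its seed argument are hypothetical illustrations rather than torchbearer code, and num_workers (parallel loading) is not modelled.

```python
import random
from typing import List, Optional, Sequence, Tuple


def make_batches(x: Sequence, y: Sequence, batch_size: int = 1, shuffle: bool = True,
                 seed: Optional[int] = None) -> List[Tuple[list, list]]:
    """Hypothetical sketch of one epoch of (x, y) mini-batches."""
    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of samples")
    indices = list(range(len(x)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # one permutation keeps x and y aligned
    batches = []
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        batches.append(([x[i] for i in chunk], [y[i] for i in chunk]))
    return batches
```

The last batch may be smaller than batch_size, which matches the default behaviour of torch.utils.data.DataLoader (drop_last=False).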

for_val_steps(steps)

Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set, which is useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader keeps being refreshed until stopped by the STOP_TRAINING flag or similar.

Example:

# Simple trial that runs for 10 validation iterations on no data
>>> from torchbearer import Trial
>>> trial = Trial(None).for_val_steps(10).run(1)
Parameters: steps (int) – The number of validation steps per epoch to run
Returns: self
Return type: Trial

with_val_generator(generator, steps=None)

Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 100 validation iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=False, download=True, transform=transforms.ToTensor()))
>>> trial = Trial(None).with_val_generator(dataloader).for_val_steps(100).run(1)
Parameters:
  • generator – The validation data generator to use during calls to run() and evaluate()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial

with_val_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)

Use this trial with the given validation data. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 10 validation iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> trial = Trial(None).with_val_data(data, targets).for_val_steps(10).run(1)
Parameters:
  • x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
  • y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
  • batch_size (int) – The size of each batch to sample from the data
  • shuffle (bool) – If True, then data will be shuffled each epoch
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial

for_test_steps(steps)

Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set, which is useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader keeps being refreshed until stopped by the STOP_TRAINING flag or similar.

Example:

# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> predictions = Trial(None).with_test_data(data).for_test_steps(10).predict()
Parameters: steps (int) – The number of test steps per epoch to run (when using predict())
Returns: self
Return type: Trial

with_test_generator(generator, steps=None)

Use this trial with the given test generator. Returns self so that methods can be chained for convenience.

Example:

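# Simple trial that runs for 10 test iterations on the MNIST test set
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> testloader = DataLoader(MNIST('./data/', train=False, download=True, transform=transforms.ToTensor()))
>>> predictions = Trial(None).with_test_generator(testloader).for_test_steps(10).predict()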
Parameters:
  • generator – The test data generator to use during calls to predict()
  • steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial

with_test_data(x, batch_size=1, num_workers=1, steps=None)

Use this trial with the given test data. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> predictions = Trial(None).with_test_data(data).for_test_steps(10).predict()
Parameters:
  • x (torch.Tensor) – The test x data to use during calls to predict()
  • batch_size (int) – The size of each batch to sample from the data
  • num_workers (int) – Number of worker threads to use in the data loader
  • steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial

for_steps(train_steps=None, val_steps=None, test_steps=None)

Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained for convenience. If a step count is larger than the dataset size, the loader is refreshed as if a new epoch had started. If a step count is -1, the loader keeps being refreshed until stopped by the STOP_TRAINING flag or similar.

Example:

# Simple trial that runs for 10 training, validation and test iterations on some random data
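>>> import torch
>>> from torchbearer import Trial
>>> train_data, train_targets = torch.rand(10, 1), torch.rand(10, 1)
>>> val_data, val_targets = torch.rand(10, 1), torch.rand(10, 1)
>>> test_data = torch.rand(10, 1)
>>> trial = Trial(None).with_train_data(train_data, train_targets).with_val_data(val_data, val_targets)
>>> trial = trial.with_test_data(test_data).for_steps(10, 10, 10)
>>> history = trial.run(1)
>>> predictions = trial.predict()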
Parameters:
  • train_steps (int) – The number of training steps per epoch to run
  • val_steps (int) – The number of validation steps per epoch to run
  • test_steps (int) – The number of test steps per epoch to run (when using predict())
Returns: self
Return type: Trial

with_generators(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)

Use this trial with the given generators. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 100 steps from a training and validation data generator
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> to_tensor = transforms.ToTensor()
>>> trainloader = DataLoader(MNIST('./data/', train=True, download=True, transform=to_tensor))
>>> valloader = DataLoader(MNIST('./data/', train=False, download=True, transform=to_tensor))
>>> trial = Trial(None).with_generators(trainloader, valloader, train_steps=100, val_steps=100).run(1)
Parameters:
  • train_generator – The training data generator to use during calls to run()
  • val_generator – The validation data generator to use during calls to run() and evaluate()
  • test_generator – The testing data generator to use during calls to predict()
  • train_steps (int) – The number of steps per epoch to take when using the training generator
  • val_steps (int) – The number of steps per epoch to take when using the validation generator
  • test_steps (int) – The number of steps per epoch to take when using the testing generator
Returns: self
Return type: Trial

with_data(x_train=None, y_train=None, x_val=None, y_val=None, x_test=None, batch_size=1, num_workers=1, train_steps=None, val_steps=None, test_steps=None, shuffle=True)

Use this trial with the given data. Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> test_data = torch.rand(10, 1)
>>> trial = Trial(None).with_data(x_train=data, y_train=targets, x_test=test_data)
>>> predictions = trial.for_test_steps(10).predict()
Parameters:
  • x_train (torch.Tensor) – The training data to use
  • y_train (torch.Tensor) – The training targets to use
  • x_val (torch.Tensor) – The validation data to use
  • y_val (torch.Tensor) – The validation targets to use
  • x_test (torch.Tensor) – The test data to use
  • batch_size (int) – Batch size to use in mini-batching
  • num_workers (int) – Number of workers to use for data loading and batching
  • train_steps (int) – Number of steps for each training pass
  • val_steps (int) – Number of steps for each validation pass
  • test_steps (int) – Number of steps for each test pass
  • shuffle (bool) – If True, shuffle training and validation data.
Returns: self
Return type: Trial

for_inf_train_steps()

Use this trial with an infinite number of training steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs training data until stopped
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True, download=True, transform=transforms.ToTensor()))
>>> trial = Trial(None).with_train_generator(trainloader).for_inf_train_steps()
>>> trial.run(1)
Returns: self
Return type: Trial

for_inf_val_steps()

Use this trial with an infinite number of validation steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs validation data until stopped
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> valloader = DataLoader(MNIST('./data/', train=False, download=True, transform=transforms.ToTensor()))
>>> trial = Trial(None).with_val_generator(valloader).for_inf_val_steps()
>>> trial.run(1)
Returns: self
Return type: Trial

for_inf_test_steps()

Use this trial with an infinite number of test steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.

Example:

# Simple trial that runs test data until stopped
>>> import torch
>>> from torchbearer import Trial
>>> test_data = torch.rand(1000, 10)
>>> trial = Trial(None).with_test_data(test_data).for_inf_test_steps()
>>> predictions = trial.predict()
Returns: self
Return type: Trial

for_inf_steps(train=True, val=True, test=True)

Use this trial with an infinite number of steps for each data type whose flag is True. Returns self so that methods can be chained for convenience.

Example:

# Simple trial with infinite training and test steps (validation runs for its normal length)
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> to_tensor = transforms.ToTensor()
>>> trainloader = DataLoader(MNIST('./data/', train=True, download=True, transform=to_tensor))
>>> valloader = DataLoader(MNIST('./data/', train=False, download=True, transform=to_tensor))
>>> trial = Trial(None).with_generators(trainloader, valloader).for_inf_steps(True, False, True)
>>> trial.run(1)
Parameters:
  • train (bool) – Use an infinite number of training steps
  • val (bool) – Use an infinite number of validation steps
  • test (bool) – Use an infinite number of test steps
Returns: self
Return type: Trial

with_inf_train_loader()

Use this trial with a training iterator that is refreshed when it is exhausted instead of at each epoch. This allows the number of training steps to be set lower than the size of the generator while the model is still trained on all training samples, provided enough “epochs” are run.

Example:

# Simple trial that runs 10 epochs of 100 training iterations each, without reshuffling until all data has been seen
>>> from torchbearer import Trial
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True, download=True, transform=transforms.ToTensor()))
>>> trial = Trial(None).with_train_generator(trainloader).for_train_steps(100).with_inf_train_loader()
>>> trial.run(10)
Returns: self
Return type: Trial
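
The difference between a per-epoch iterator and the persistent one enabled by with_inf_train_loader can be sketched in plain Python. Here a single iterator is kept across epochs and only re-created once it is exhausted, so three epochs of two steps over six samples visit every sample exactly once. The class PersistentLoader is hypothetical and only models the behaviour described above, not torchbearer's implementation.

```python
from typing import Iterable, List


class PersistentLoader:
    """Hypothetical sketch: keep one iterator alive across epochs."""

    def __init__(self, data: Iterable):
        self.data = data
        self.iterator = iter(data)

    def next_batch(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self.iterator = iter(self.data)  # refresh only when exhausted
            return next(self.iterator)

    def run_epoch(self, steps: int) -> List:
        # Unlike a per-epoch loader, the iterator is NOT reset at the start of an epoch
        return [self.next_batch() for _ in range(steps)]
```

With PersistentLoader([1, 2, 3, 4, 5, 6]), three calls to run_epoch(2) return [1, 2], [3, 4] and [5, 6]; a per-epoch loader would instead return [1, 2] every time.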

with_loader(batch_loader)

Use this trial with a custom batch loader. The loader usually calls next() on state[torchbearer.ITERATOR] and populates state[torchbearer.X] and state[torchbearer.Y_TRUE].

Example:

# Simple trial that runs with a custom loader function that populates X and Y_TRUE in state with random data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> def custom_loader(state):
...     state[torchbearer.X], state[torchbearer.Y_TRUE] = torch.rand(5, 5), torch.rand(5, 5)
>>> trial = Trial(None).with_loader(custom_loader)
>>> trial.run(10)
Parameters: batch_loader (function) – Function of state that extracts data from the data loader (stored under torchbearer.ITERATOR), stores it in state and sends it to the correct device
Returns: self
Return type: Trial

with_closure(closure)

Use this trial with a custom closure.

Example:

# Simple trial that runs with a custom closure
>>> import torchbearer
>>> from torchbearer import Trial
>>> def custom_closure(state):
...     print(state[torchbearer.BATCH])
>>> trial = Trial(None).with_closure(custom_closure).for_steps(3)
>>> _ = trial.run(1)
0
1
2
Parameters: closure (function) – Function of state that defines the custom closure
Returns: self
Return type: Trial

run(epochs=1, verbose=-1)

Run this trial for the given number of epochs, starting from the last trained epoch.

Example:

# Simple trial that runs for 100 training steps in a single epoch
>>> from torchbearer import Trial
>>> trial = Trial(None).for_steps(100)
>>> _ = trial.run(1)
Parameters:
  • epochs (int, optional) – The number of epochs to run for
  • verbose (int, optional) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
State Requirements:
  • torchbearer.state.MODEL: Model should be callable and not none, set on Trial init
Returns: The model history (a list of tuples of steps summaries and epoch metric dicts)
Return type: list

evaluate(verbose=-1, data_key=None)

Evaluate this trial on the validation data.

Example:

# Simple trial to evaluate on both validation and test data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> val_data, val_targets = torch.rand(5, 5), torch.rand(5, 5)
>>> test_data = torch.rand(5, 5)
>>> t = Trial(None).with_val_data(val_data, val_targets).with_test_data(test_data)
>>> val_metrics = t.evaluate(data_key=torchbearer.VALIDATION_DATA)
>>> test_metrics = t.evaluate(data_key=torchbearer.TEST_DATA)
Parameters:
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
  • data_key (StateKey) – Optional StateKey for the data to evaluate on. Default: torchbearer.VALIDATION_DATA
Returns: The final metric values
Return type: dict

predict(verbose=-1, data_key=None)

Determine predictions for this trial on the test data.

Example:

# Simple trial to predict on some test data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> test_data = torch.rand(5, 5)
>>> t = Trial(None).with_test_data(test_data)
>>> test_predictions = t.predict(data_key=torchbearer.TEST_DATA)
Parameters:
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
  • data_key (StateKey) – Optional StateKey for the data to predict on. Default: torchbearer.TEST_DATA
Returns: Model outputs as a list
Return type: list

replay(callbacks=None, verbose=2, one_batch=False)

Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

Example:

>>> import torch
>>> from torchbearer import Trial
>>> state = torch.load('some_state.pt')
>>> t = Trial(None).load_state_dict(state)
>>> t.replay()
Parameters:
  • callbacks (list) – List of callbacks to be run during the replay
  • verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
  • one_batch (bool) – If True, only one batch per epoch is replayed. If False, all batches are replayed
Returns: self
Return type: Trial

train()

Set model and metrics to training mode.

Example:

>>> from torchbearer import Trial
>>> t = Trial(None).train()
Returns: self
Return type: Trial

eval()

Set model and metrics to evaluation mode.

Example:

>>> from torchbearer import Trial
>>> t = Trial(None).eval()
Returns: self
Return type: Trial

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

Example:

>>> from torchbearer import Trial
>>> t = Trial(None).to('cuda:1')
Parameters:
  • args – See: torch.nn.Module.to
  • kwargs – See: torch.nn.Module.to
Returns: self
Return type: Trial

cuda(device=None)

Moves all model parameters and buffers to the GPU.

Example:

>>> from torchbearer import Trial
>>> t = Trial(None).cuda()
Parameters: device (int) – If specified, all parameters will be copied to that device
Returns: self
Return type: Trial

cpu()

Moves all model parameters and buffers to the CPU.

Example:

>>> from torchbearer import Trial
>>> t = Trial(None).cpu()
Returns: self
Return type: Trial

state_dict(**kwargs)

Get a dict containing the model and optimizer states, as well as the model history.

Example:

>>> from torchbearer import Trial
>>> t = Trial(None)
>>> state = t.state_dict() # State dict that can now be saved with torch.save
Parameters: kwargs – See: torch.nn.Module.state_dict
Returns: A dict containing the model and optimizer states and the model history
Return type: dict

load_state_dict(state_dict, resume=True, **kwargs)

Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, just load the model state when resume=False.

Example:

>>> import torch
>>> from torchbearer import Trial
>>> t = Trial(None)
>>> state = torch.load('some_state.pt')
>>> t.load_state_dict(state)
Parameters:
  • state_dict (dict) – The state dict to reload
  • resume (bool) – If True, resume from the given state. Else, just load in the model weights.
  • kwargs – See: torch.nn.Module.load_state_dict
Returns: self
Return type: Trial

Batch Loaders

torchbearer.trial.load_batch_infinite(loader)

Wraps a batch loader and refreshes the iterator once it has been exhausted.

Parameters: loader – The batch loader to wrap
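
A minimal sketch of the wrapping idea, assuming a batch loader is a function of state that calls next() on the iterator stored in state, as with_loader describes. Plain string keys ('generator', 'iterator', 'x', 'y_true') stand in for torchbearer's StateKeys, and wrap_infinite and load_pair are hypothetical names; this is not the library source.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def wrap_infinite(loader: Callable[[State], None]) -> Callable[[State], None]:
    """Hypothetical sketch: wrap `loader` so the iterator is refreshed once exhausted."""
    def call(state: State) -> None:
        try:
            loader(state)
        except StopIteration:
            state['iterator'] = iter(state['generator'])  # start a fresh pass over the data
            loader(state)
    return call


def load_pair(state: State) -> None:
    """Hypothetical standard loader: pull an (x, y) pair from the iterator into state."""
    state['x'], state['y_true'] = next(state['iterator'])
```

Calling the wrapped loader five times on a two-item generator yields the inputs 1, 2, 1, 2, 1 instead of raising StopIteration on the third call.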

torchbearer.trial.load_batch_none(state)

Load a (None, None) tuple mini-batch into state.

Parameters: state (dict) – The current state dict of the Trial.

torchbearer.trial.load_batch_predict(state)

Load a prediction (input data, target) or (input data) mini-batch from the iterator into state.

Parameters: state (dict) – The current state dict of the Trial.

torchbearer.trial.load_batch_standard(state)

Load a standard (input data, target) tuple mini-batch from the iterator into state.

Parameters: state (dict) – The current state dict of the Trial.
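
The three loaders differ mainly in how they unpack each item from the iterator. The sketch below uses plain string keys instead of torchbearer's StateKeys and omits moving data to the correct device and dtype. It shows one plausible way to handle a standard (input, target) pair, a prediction batch that may or may not carry a target, and the (None, None) case; the function names are hypothetical, and treating any 2-item tuple or list as (input, target) is an assumption.

```python
from typing import Any, Dict

State = Dict[str, Any]


def load_standard(state: State) -> None:
    # A standard batch is always an (input data, target) pair
    state['x'], state['y_true'] = next(state['iterator'])


def load_predict(state: State) -> None:
    # A prediction batch is either (input data, target) or bare input data.
    # Assumption: a 2-item tuple/list is read as (input, target).
    data = next(state['iterator'])
    if isinstance(data, (tuple, list)) and len(data) == 2:
        state['x'], state['y_true'] = data
    else:
        state['x'], state['y_true'] = data, None


def load_none(state: State) -> None:
    # Used when no generator has been set: both entries are None
    state['x'], state['y_true'] = None, None
```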

Misc

torchbearer.trial.deep_to(batch, device, dtype)

Function that calls to() on tensors, tuples, lists or dicts. Every item inside a tuple, list or dict has deep_to() called on it recursively.

Example:

>>> import torch
>>> from torchbearer import deep_to
>>> example_dict = {'a': torch.ones(5)*2.1, 'b': torch.ones(1)*5.9}
>>> deep_to(example_dict, device='cpu', dtype=torch.int)
{'a': tensor([2, 2, 2, 2, 2], dtype=torch.int32), 'b': tensor([5], dtype=torch.int32)}
Parameters:
  • batch (tuple / list / torch.Tensor / dict) – The mini-batch which requires a to() call
  • device (torch.device) – The desired device of the batch
  • dtype (torch.dtype) – The desired datatype of the batch
Returns: The moved or cast batch
Return type: tuple / list / torch.Tensor / dict
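
The recursion behind deep_to can be shown without PyTorch. In this sketch, deep_apply is a hypothetical helper that rebuilds tuples, lists and dicts and applies a function to every leaf; passing lambda t: t.to(device, dtype) as fn would approximate the tensor case. The sketch returns new containers and makes no claim about how deep_to itself treats every tensor dtype.

```python
from typing import Any, Callable


def deep_apply(batch: Any, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: apply `fn` to every leaf of nested tuples, lists and dicts."""
    if isinstance(batch, tuple):
        return tuple(deep_apply(item, fn) for item in batch)
    if isinstance(batch, list):
        return [deep_apply(item, fn) for item in batch]
    if isinstance(batch, dict):
        return {key: deep_apply(value, fn) for key, value in batch.items()}
    return fn(batch)  # a leaf, e.g. a tensor that would receive .to(device, dtype)
```

Mirroring the deep_to example above, deep_apply({'a': 2.1, 'b': (5.9, [1.5])}, int) truncates every leaf while keeping the nesting: {'a': 2, 'b': (5, [1])}.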

torchbearer.trial.update_device_and_dtype(state, *args, **kwargs)

Gets the data type and device values from the args / kwargs and updates state with them.

Parameters:
  • state (State) – The State to update
  • args – Arguments to the Trial.to() function
  • kwargs – Keyword arguments to the Trial.to() function
Returns: state

State

The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced during model fitting. This module defines classes for interacting with state and all of the built-in state keys used throughout torchbearer. The state_key() function can be used to create custom state keys for use in callbacks or metrics.

Example:

>>> from torchbearer import state_key
>>> MY_KEY = state_key('my_test_key')

State

class torchbearer.state.State

State dictionary that behaves like a Python dict but accepts StateKeys

data
get_key(statekey)
update([E, ]**F) → None. Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].

class torchbearer.state.StateKey(key)

StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.

Parameters: key (str) – Base key

process(state)

Process the state and update the metric for one iteration.

Parameters: state (dict) – The state given to the metric (a root level metric is given the Trial state)
Returns: None, or the value of the metric for this batch

process_final(state)

Process the terminal state and output the final value of the metric.

Parameters: state (dict) – The state given to the metric (a root level metric is given the Trial state)
Returns: None, or the value of the metric for this epoch

torchbearer.state.state_key(key)

Computes and returns a non-conflicting key for the state dictionary when given a seed key

Parameters: key (str) – The seed key, used as the basis for the new state key
Returns: New state key
Return type: StateKey
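
One way to produce non-conflicting keys is to remember every key created so far and append an increasing numeric suffix on a clash. The sketch below assumes that scheme; the exact suffix format torchbearer uses may differ, and SketchStateKey, make_state_key and _registered are hypothetical names. Because each key object hashes and compares by its unique string, an ordinary dict can be indexed by these keys, which is the property the State dictionary relies on.

```python
from typing import Set

_registered: Set[str] = set()  # every key string handed out so far


class SketchStateKey:
    """Hypothetical stand-in for a StateKey: hashes and compares by its unique string."""

    def __init__(self, key: str):
        self.key = key

    def __hash__(self) -> int:
        return hash(self.key)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SketchStateKey) and other.key == self.key

    def __repr__(self) -> str:
        return self.key


def make_state_key(seed: str) -> SketchStateKey:
    """Return a key based on `seed` that does not clash with earlier keys."""
    candidate, count = seed, 0
    while candidate in _registered:
        count += 1
        candidate = f"{seed}_{count}"  # assumed suffix format
    _registered.add(candidate)
    return SketchStateKey(candidate)
```

Calling make_state_key('my_test_key') three times would give keys named my_test_key, my_test_key_1 and my_test_key_2, so two callbacks that pick the same seed string cannot overwrite each other's entries.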

Key List

torchbearer.state.BACKWARD_ARGS = backward_args

The optional arguments which should be passed to the backward call

torchbearer.state.BATCH = t

The current batch number

torchbearer.state.CALLBACK_LIST = callback_list

The CallbackList object which is called by the Trial

torchbearer.state.CRITERION = criterion

The criterion to use when model fitting

torchbearer.state.DATA = data

The string name of the current data

torchbearer.state.DATA_TYPE = dtype

The data type of tensors in use by the model, match this to avoid type issues

torchbearer.state.DEVICE = device

The device currently in use by the Trial and PyTorch model

torchbearer.state.EPOCH = epoch

The current epoch number

torchbearer.state.FINAL_PREDICTIONS = final_predictions

The key which maps to the predictions over the dataset when calling predict

torchbearer.state.GENERATOR = generator

The current data generator (DataLoader)

torchbearer.state.HISTORY = history

The history list of the Trial instance

torchbearer.state.INF_TRAIN_LOADING = inf_train_loading

Flag for refreshing the training iterator when it is exhausted instead of at each epoch

torchbearer.state.INPUT = x

The current batch of inputs

torchbearer.state.ITERATOR = iterator

The current iterator

torchbearer.state.LOADER = loader

The batch loader which handles formatting data from each batch

torchbearer.state.LOSS = loss

The current value for the loss

torchbearer.state.MAX_EPOCHS = max_epochs

The total number of epochs to run for

torchbearer.state.METRICS = metrics

The metric dict from the current batch of data

torchbearer.state.METRIC_LIST = metric_list

The list of metrics in use by the Trial

torchbearer.state.MIXUP_LAMBDA = mixup_lambda

The lambda coefficient of the linear combination of inputs

torchbearer.state.MIXUP_PERMUTATION = mixup_permutation

The permutation of input indices for input mixup

torchbearer.state.MODEL = model

The PyTorch module / model that will be trained

torchbearer.state.OPTIMIZER = optimizer

The optimizer to use when model fitting

torchbearer.state.PREDICTION = y_pred

The current batch of predictions

torchbearer.state.SAMPLER = sampler

The sampler which loads data from the generator onto the correct device

torchbearer.state.SELF = self

A self reference to the Trial object for persistence etc.

torchbearer.state.STEPS = steps

The current number of steps per epoch

torchbearer.state.STOP_TRAINING = stop_training

A flag that can be set to True to stop the current fit call

torchbearer.state.TARGET = y_true

The current batch of ground truth data

torchbearer.state.TEST_DATA = test_data

The flag representing test data

torchbearer.state.TEST_GENERATOR = test_generator

The test data generator in the Trial object

torchbearer.state.TEST_STEPS = test_steps

The number of test steps to take

torchbearer.state.TIMINGS = timings

The timings keys used by the timer callback

torchbearer.state.TRAIN_DATA = train_data

The flag representing train data

torchbearer.state.TRAIN_GENERATOR = train_generator

The train data generator in the Trial object

torchbearer.state.TRAIN_STEPS = train_steps

The number of train steps to take

torchbearer.state.VALIDATION_DATA = validation_data

The flag representing validation data

torchbearer.state.VALIDATION_GENERATOR = validation_generator

The validation data generator in the Trial object

torchbearer.state.VALIDATION_STEPS = validation_steps

The number of validation steps to take

torchbearer.state.VERSION = torchbearer_version

The torchbearer version

torchbearer.state.X = x

The current batch of inputs

torchbearer.state.Y_PRED = y_pred

The current batch of predictions

torchbearer.state.Y_TRUE = y_true

The current batch of ground truth data

Utilities

class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)

Generates training and validation split indices for a given dataset length and creates training and validation datasets using these splits

Parameters:
  • dataset_len – The length of the dataset to be split into training and validation
  • split_fraction – The fraction of the whole dataset to be used for validation
  • shuffle_seed – Optional random seed for the shuffling process

get_train_dataset(dataset)

Creates a training dataset from an existing dataset

Parameters: dataset (torch.utils.data.Dataset) – Dataset to be split into a training dataset
Returns: Training dataset split from the whole dataset
Return type: torch.utils.data.Dataset

get_val_dataset(dataset)

Creates a validation dataset from an existing dataset

Parameters: dataset (torch.utils.data.Dataset) – Dataset to be split into a validation dataset
Returns: Validation dataset split from the whole dataset
Return type: torch.utils.data.Dataset

class torchbearer.cv_utils.SubsetDataset(dataset, ids)

Dataset that consists of a subset of a previous dataset

Parameters:
  • dataset (torch.utils.data.Dataset) – Complete dataset
  • ids (list) – List of subset IDs
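
The index-based splitting described for DatasetValidationSplitter and SubsetDataset can be sketched with plain Python sequences. This sketch assumes the validation indices are taken from the end of a seeded, shuffled index list and that the validation size is rounded down; which end torchbearer uses and how it rounds are assumptions here. split_indices and Subset are hypothetical names.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_indices(dataset_len: int, split_fraction: float,
                  shuffle_seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: return (train_ids, val_ids) for a dataset of the given length."""
    ids = list(range(dataset_len))
    random.Random(shuffle_seed).shuffle(ids)  # a fixed seed makes the split reproducible
    n_val = int(dataset_len * split_fraction)
    n_train = dataset_len - n_val
    return ids[:n_train], ids[n_train:]


class Subset:
    """Hypothetical view over `dataset` restricted to the given ids (cf. SubsetDataset)."""

    def __init__(self, dataset: Sequence, ids: List[int]):
        self.dataset, self.ids = dataset, ids

    def __len__(self) -> int:
        return len(self.ids)

    def __getitem__(self, index: int):
        return self.dataset[self.ids[index]]
```

Because the subsets only store indices, the underlying dataset is never copied; the same seed always reproduces the same split.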

torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)

Generate validation and training datasets from whole dataset tensors

Parameters:
  • x (torch.Tensor) – Data tensor for dataset
  • y (torch.Tensor) – Label tensor for dataset
  • validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
  • validation_split (float) – Fraction of dataset to be used for validation
  • shuffle (bool) – If True, randomize tensor order before splitting; otherwise keep the original order
Returns: Training and validation datasets

torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)

Generate training and validation tensors from whole dataset data and label tensors

Parameters:
  • x (torch.Tensor) – Data tensor for whole dataset
  • y (torch.Tensor) – Label tensor for whole dataset
  • split (float) – Fraction of dataset to be used for validation
  • shuffle (bool) – If True, randomize tensor order before splitting; otherwise keep the original order
Returns: Training and validation tensors (training data, training labels, validation data, validation labels)
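
The key detail when splitting paired tensors is that one permutation must be applied to both x and y so that every sample keeps its label. Below is a plain-Python sketch on lists, assuming the validation portion is taken from the end of the (optionally shuffled) data and its size is rounded down; split_pairs and its seed argument are hypothetical, and the real function works on tensors.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_pairs(x: Sequence, y: Sequence, split: float, shuffle: bool = True,
                seed: Optional[int] = None) -> Tuple[List, List, List, List]:
    """Hypothetical sketch returning (x_train, y_train, x_val, y_val)."""
    order = list(range(len(x)))
    if shuffle:
        random.Random(seed).shuffle(order)  # one permutation shared by x and y
    n_val = int(len(order) * split)
    n_train = len(order) - n_val
    train, val = order[:n_train], order[n_train:]
    return ([x[i] for i in train], [y[i] for i in train],
            [x[i] for i in val], [y[i] for i in val])
```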

torchbearer.bases.base_closure(x, model, y_pred, y_true, crit, loss, opt)

Function to create a standard PyTorch closure using objects taken from state under the given keys.

Parameters:
  • x – State key under which the input data is stored
  • model – State key under which the PyTorch model is stored
  • y_pred – State key under which the predictions will be stored
  • y_true – State key under which the targets are stored
  • crit – State key under which the criterion function is stored (function of state or (y_pred, y_true))
  • loss – State key under which the loss will be stored
  • opt – State key under which the optimizer is stored
Returns: Standard closure function
Return type: function
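
A minimal sketch of what such a closure typically does, assuming plain string keys for state and a criterion of the (y_pred, y_true) form (the function-of-state form is not handled). It performs the usual PyTorch sequence of zero_grad, forward pass, loss computation and backward pass, reading every object from state by key. make_closure is a hypothetical name; torchbearer's own closure may also handle details not shown here, for example passing the optional BACKWARD_ARGS to the backward call.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def make_closure(x: str, model: str, y_pred: str, y_true: str,
                 crit: str, loss: str, opt: str) -> Callable[[State], None]:
    """Hypothetical sketch: build a closure that reads every object from state by key."""
    def closure(state: State) -> None:
        state[opt].zero_grad()                                   # clear old gradients
        state[y_pred] = state[model](state[x])                   # forward pass
        state[loss] = state[crit](state[y_pred], state[y_true])  # compute the loss
        state[loss].backward()                                   # backpropagate
    return closure
```

With real objects, state[model] would be a torch.nn.Module, state[opt] a torch.optim.Optimizer and state[loss] a scalar tensor; an optimizer step after the closure completes one training iteration.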