torchbearer¶
Trial¶
- class torchbearer.trial.CallbackListInjection(callback, callback_list)[source]¶ This class allows for a callback to be injected into a callback list, without masking the methods available for mutating the list. In this way, callbacks (such as printers) can be injected seamlessly into the methods of the trial class.
  Parameters: - callback (Callback) – The Callback to inject
  - callback_list (CallbackList) – The underlying CallbackList
- load_state_dict(state_dict)[source]¶ Resume this callback list from the given state. Callbacks must be given in the same order for this to work.
  Parameters: state_dict (dict) – The state dict to reload
  Returns: self
  Return type: CallbackList
- class torchbearer.trial.MockOptimizer[source]¶ The MockOptimizer will be used in place of an optimizer in the event that none is passed to the Trial class.
- class torchbearer.trial.Sampler(batch_loader)[source]¶ Sampler wraps a batch loader function and executes it when Sampler.sample() is called.
  Parameters: batch_loader (func) – The batch loader to execute
- class torchbearer.trial.Trial(model, optimizer=None, criterion=None, metrics=[], callbacks=[], verbose=2)[source]¶ The trial class contains all of the required hyper-parameters for model running in torchbearer and presents an API for model fitting, evaluating and predicting.
  @article{2018torchbearer, title={Torchbearer: A Model Fitting Library for PyTorch}, author={Harris, Ethan and Painter, Matthew and Hare, Jonathon}, journal={arXiv preprint arXiv:1809.03363}, year={2018} }
  Parameters: - model (torch.nn.Module) – The base PyTorch model
  - optimizer (torch.optim.Optimizer) – The optimizer used for PyTorch model weight updates
  - criterion (func / None) – The final loss criterion that provides a loss value to the optimizer
  - metrics (list) – The list of torchbearer.Metric instances to process during fitting
  - callbacks (list) – The list of torchbearer.Callback instances to call during fitting
  - verbose (int) – Global verbosity. If 2: use tqdm on batch, if 1: use tqdm on epoch, if 0: display no training progress
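Example: a minimal fitting sketch using the methods documented below (the linear model and random data are illustrative placeholders, not part of the API):
import torch
import torch.nn as nn
from torchbearer import Trial

model = nn.Linear(10, 1)  # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
x, y = torch.rand(100, 10), torch.rand(100, 1)  # illustrative data

trial = Trial(model, optimizer, criterion, metrics=['loss'])
trial.with_train_data(x, y, batch_size=16)
history = trial.run(epochs=5)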
- cuda(device=None)[source]¶ Moves all model parameters and buffers to the GPU.
  Parameters: device (int) – If specified, all parameters will be copied to that device
  Returns: self
  Return type: Trial
- evaluate(verbose=-1, data_key=None)[source]¶ Evaluate this trial on the validation data.
  Parameters: - verbose (int) – If 2: use tqdm on batch, if 1: use tqdm on epoch, if 0: display no progress, if -1: automatic
  - data_key (StateKey) – Optional key for the data to evaluate on
  Returns: The final metric values
  Return type: dict
- for_inf_steps(train=True, val=True, test=True)[source]¶ Use this trial with infinite steps. Returns self so that methods can be chained for convenience.
  Parameters: - train (bool) – Use an infinite number of training steps
  - val (bool) – Use an infinite number of validation steps
  - test (bool) – Use an infinite number of test steps
  Returns: self
  Return type: Trial
- for_inf_test_steps()[source]¶ Use this trial with an infinite number of test steps (until stopped via the STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
  Returns: self
  Return type: Trial
- for_inf_train_steps()[source]¶ Use this trial with an infinite number of training steps (until stopped via the STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
  Returns: self
  Return type: Trial
- for_inf_val_steps()[source]¶ Use this trial with an infinite number of validation steps (until stopped via the STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
  Returns: self
  Return type: Trial
- for_steps(train_steps=None, val_steps=None, test_steps=None)[source]¶ Use this trial for the given number of train, val and test steps, as shown in the sketch below. Returns self so that methods can be chained for convenience. If the number of steps is larger than the dataset size, the loader will be refreshed as if a new epoch had begun. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
  Parameters: - train_steps (int) – The number of training steps per epoch to run
  - val_steps (int) – The number of validation steps per epoch to run
  - test_steps (int) – The number of test steps per epoch to run (when using predict())
  Returns: self
  Return type: Trial
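Example: a minimal sketch, assuming the trial from the earlier example:
trial.for_steps(train_steps=100).run(epochs=3)  # 100 training steps per "epoch", refreshing the loader as needed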
- for_test_steps(steps)[source]¶ Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set; this is useful for differentiable programming. Returns self so that methods can be chained for convenience. If the number of steps is larger than the dataset size, the loader will be refreshed as if a new epoch had begun. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
  Parameters: steps (int) – The number of test steps per epoch to run (when using predict())
  Returns: self
  Return type: Trial
- for_train_steps(steps)[source]¶ Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set; this is useful for differentiable programming. Returns self so that methods can be chained for convenience. If the number of steps is larger than the dataset size, the loader will be refreshed as if a new epoch had begun. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
  Parameters: steps (int) – The number of training steps per epoch to run
  Returns: self
  Return type: Trial
- for_val_steps(steps)[source]¶ Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set; this is useful for differentiable programming. Returns self so that methods can be chained for convenience. If the number of steps is larger than the dataset size, the loader will be refreshed as if a new epoch had begun. If steps is -1, the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
  Parameters: steps (int) – The number of validation steps per epoch to run
  Returns: self
  Return type: Trial
- load_state_dict(state_dict, resume=True, **kwargs)[source]¶ Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, load just the model state when resume=False.
  Parameters: - state_dict (dict) – The state dict to reload
  - resume (bool) – If True, resume from the given state; else, just load in the model weights
  - kwargs – See: torch.nn.Module.load_state_dict
  Returns: self
  Return type: Trial
- predict(verbose=-1, data_key=None)[source]¶ Determine predictions for this trial on the test data.
  Parameters: - verbose (int) – If 2: use tqdm on batch, if 1: use tqdm on epoch, if 0: display no progress, if -1: automatic
  - data_key (StateKey) – Optional key for the data to predict on
  Returns: Model outputs as a list
  Return type: list
- replay(callbacks=[], verbose=2, one_batch=False)[source]¶ Replay the fit passes stored in history with the given callbacks; useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.
  Parameters: - callbacks (list) – List of callbacks to be run during the replay
  - verbose (int) – If 2: use tqdm on batch, if 1: use tqdm on epoch, if 0: display no training progress
  - one_batch (bool) – If True, only one batch per epoch is replayed; if False, all batches are replayed
  Returns: self
  Return type: Trial
- run(epochs=1, verbose=-1)[source]¶ Run this trial for the given number of epochs, starting from the last trained epoch.
  Parameters: - epochs (int, optional) – The number of epochs to run for
  - verbose (int, optional) – If 2: use tqdm on batch, if 1: use tqdm on epoch, if 0: display no training progress, if -1: automatic
  State Requirements: - torchbearer.state.MODEL: Model should be callable and not None, set on Trial init
  Returns: The model history (list of tuples of steps summary and epoch metric dicts)
  Return type: list
- state_dict(**kwargs)[source]¶ Get a dict containing the model and optimizer states, as well as the model history.
  Parameters: kwargs – See: torch.nn.Module.state_dict
  Returns: A dict containing parameters and persistent buffers
  Return type: dict
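Example: a minimal checkpointing sketch, assuming the trial from the earlier example ('trial.pt' is an illustrative path). Rebuild the trial the same way before reloading, or pass resume=False to restore only the model weights:
import torch

torch.save(trial.state_dict(), 'trial.pt')  # save model, optimizer and history
trial.load_state_dict(torch.load('trial.pt'))  # resume from the saved state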
- to(*args, **kwargs)[source]¶ Moves and/or casts the parameters and buffers.
  Parameters: - args – See: torch.nn.Module.to
  - kwargs – See: torch.nn.Module.to
  Returns: self
  Return type: Trial
- with_closure(closure)[source]¶ Use this trial with a custom closure.
  Parameters: closure (function) – Function of state that defines the custom closure
  Returns: self
  Return type: Trial
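Example: a sketch of a custom closure that mirrors a standard optimisation step, assuming the trial from the earlier example (the exact default closure may differ; the state keys used are documented in the State section below):
import torchbearer

def closure(state):
    # zero gradients, forward pass, compute loss, backward pass
    state[torchbearer.OPTIMIZER].zero_grad()
    state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X])
    state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE])
    state[torchbearer.LOSS].backward()

trial.with_closure(closure)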
- with_generators(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)[source]¶ Use this trial with the given generators. Returns self so that methods can be chained for convenience.
  Parameters: - train_generator – The training data generator to use during calls to run()
  - val_generator – The validation data generator to use during calls to run() and evaluate()
  - test_generator – The testing data generator to use during calls to predict()
  - train_steps (int) – The number of steps per epoch to take when using the training generator
  - val_steps (int) – The number of steps per epoch to take when using the validation generator
  - test_steps (int) – The number of steps per epoch to take when using the testing generator
  Returns: self
  Return type: Trial
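Example: a minimal sketch using torch DataLoaders, assuming the trial, x and y from the earlier example:
from torch.utils.data import DataLoader, TensorDataset

train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(x, y), batch_size=16)  # illustrative; normally a held-out set
trial.with_generators(train_generator=train_loader, val_generator=val_loader)
trial.run(epochs=5)
metrics = trial.evaluate()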
- with_inf_train_loader()[source]¶ Use this trial with a training iterator that refreshes when it finishes, instead of at each epoch. This allows the number of training steps to be set below the size of the generator while the model is still trained on all training samples, provided enough “epochs” are run.
  Returns: self
  Return type: Trial
- with_test_data(x, batch_size=1, num_workers=1, steps=None)[source]¶ Use this trial with the given test data. Returns self so that methods can be chained for convenience.
  Parameters: - x (torch.Tensor) – The test x data to use during calls to predict()
  - batch_size (int) – The size of each batch to sample from the data
  - num_workers (int) – Number of worker threads to use in the data loader
  - steps (int) – The number of steps per epoch to take when using this data
  Returns: self
  Return type: Trial
- with_test_generator(generator, steps=None)[source]¶ Use this trial with the given test generator. Returns self so that methods can be chained for convenience.
  Parameters: - generator – The test data generator to use during calls to predict()
  - steps (int) – The number of steps per epoch to take when using this generator
  Returns: self
  Return type: Trial
- with_train_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]¶ Use this trial with the given train data. Returns self so that methods can be chained for convenience.
  Parameters: - x (torch.Tensor) – The train x data to use during calls to run()
  - y (torch.Tensor) – The train labels to use during calls to run()
  - batch_size (int) – The size of each batch to sample from the data
  - shuffle (bool) – If True, then data will be shuffled each epoch
  - num_workers (int) – Number of worker threads to use in the data loader
  - steps (int) – The number of steps per epoch to take when using this data
  Returns: self
  Return type: Trial
- with_train_generator(generator, steps=None)[source]¶ Use this trial with the given train generator. Returns self so that methods can be chained for convenience.
  Parameters: - generator – The train data generator to use during calls to run()
  - steps (int) – The number of steps per epoch to take when using this generator
  Returns: self
  Return type: Trial
- with_val_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]¶ Use this trial with the given validation data. Returns self so that methods can be chained for convenience.
  Parameters: - x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
  - y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
  - batch_size (int) – The size of each batch to sample from the data
  - shuffle (bool) – If True, then data will be shuffled each epoch
  - num_workers (int) – Number of worker threads to use in the data loader
  - steps (int) – The number of steps per epoch to take when using this data
  Returns: self
  Return type: Trial
- with_val_generator(generator, steps=None)[source]¶ Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.
  Parameters: - generator – The validation data generator to use during calls to run() and evaluate()
  - steps (int) – The number of steps per epoch to take when using this generator
  Returns: self
  Return type: Trial
- torchbearer.trial.deep_to(batch, device, dtype)[source]¶ Static method to call to() on tensors or tuples. All items in a tuple will have deep_to() called.
  Parameters: - batch (tuple / list / torch.Tensor) – The mini-batch which requires a to() call
  - device (torch.device) – The desired device of the batch
  - dtype (torch.dtype) – The desired datatype of the batch
  Returns: The moved or cast batch
  Return type: tuple / list / torch.Tensor
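Example: a minimal sketch moving an (input, target) pair in one call (the tensors are illustrative):
import torch
from torchbearer.trial import deep_to

batch = (torch.rand(4, 3), torch.rand(4, 1))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = deep_to(batch, device, torch.float32)  # every tensor in the tuple is moved and cast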
- torchbearer.trial.inject_callback(callback)[source]¶ Decorator to inject a callback into the callback list and remove the callback after the decorated function has executed.
  Parameters: callback (Callback) – Callback to be injected
  Returns: The decorator
- torchbearer.trial.inject_printer(validation_label_letter='v')[source]¶ The inject printer decorator is used to inject the appropriate printer callback, according to the verbosity level.
  Parameters: validation_label_letter (str) – The validation label letter to use
  Returns: A decorator
- torchbearer.trial.inject_sampler(data_key, predict=False)[source]¶ Decorator to inject a Sampler into state[torchbearer.SAMPLER], along with the specified generator into state[torchbearer.GENERATOR] and the number of steps into state[torchbearer.STEPS].
  Returns: The decorator
- torchbearer.trial.load_batch_infinite(loader)[source]¶ Wraps a batch loader and refreshes the iterator once it has been completed.
  Parameters: loader – The batch loader to wrap
- torchbearer.trial.load_batch_none(state)[source]¶ Load a (None, None) tuple mini-batch into state.
  Parameters: state (dict) – The current state dict of the Trial
- torchbearer.trial.load_batch_predict(state)[source]¶ Load a prediction (input data, target) or (input data) mini-batch from the iterator into state.
  Parameters: state (dict) – The current state dict of the Trial
- torchbearer.trial.load_batch_standard(state)[source]¶ Load a standard (input data, target) tuple mini-batch from the iterator into state.
  Parameters: state (dict) – The current state dict of the Trial
- torchbearer.trial.update_device_and_dtype(state, *args, **kwargs)[source]¶ Gets the data type and device values from the args / kwargs and updates state.
  Parameters: - state (State) – The State to update
  - args – Arguments to the Trial.to() function
  - kwargs – Keyword arguments to the Trial.to() function
  Returns: device, dtype pair
State¶
The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced during model fitting. This module defines classes for interacting with state and all of the built-in state keys used throughout torchbearer. The state_key() function can be used to create custom state keys for use in callbacks or metrics.
Example:
from torchbearer import state_key
MY_KEY = state_key('my_test_key')
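A custom key can then be read and written from callbacks or metrics. A minimal sketch using the decorator callbacks from torchbearer.callbacks (the stored value here is illustrative):
import torchbearer
from torchbearer.callbacks import on_step_training

@on_step_training
def store_loss_copy(state):
    # copy the current loss value into the custom key after each training step
    state[MY_KEY] = state[torchbearer.LOSS].item()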
- torchbearer.state.BACKWARD_ARGS = backward_args¶ The optional arguments which should be passed to the backward call
- torchbearer.state.BATCH = t¶ The current batch number
- torchbearer.state.CALLBACK_LIST = callback_list¶ The CallbackList object which is called by the Trial
- torchbearer.state.CRITERION = criterion¶ The criterion to use when model fitting
- torchbearer.state.DATA = data¶ The string name of the current data
- torchbearer.state.DATA_TYPE = dtype¶ The data type of tensors in use by the model; match this to avoid type issues
- torchbearer.state.EPOCH = epoch¶ The current epoch number
- torchbearer.state.FINAL_PREDICTIONS = final_predictions¶ The key which maps to the predictions over the dataset when calling predict
- torchbearer.state.GENERATOR = generator¶ The current data generator (DataLoader)
- torchbearer.state.HISTORY = history¶ The history list of the Trial instance
- torchbearer.state.INF_TRAIN_LOADING = inf_train_loading¶ Flag for refreshing the training iterator when it finishes, instead of at each epoch
- torchbearer.state.INPUT = x¶ The current batch of inputs
- torchbearer.state.ITERATOR = iterator¶ The current iterator
- torchbearer.state.LOSS = loss¶ The current value of the loss
- torchbearer.state.MAX_EPOCHS = max_epochs¶ The total number of epochs to run for
- torchbearer.state.METRICS = metrics¶ The metric dict from the current batch of data
- torchbearer.state.MODEL = model¶ The PyTorch module / model that will be trained
- torchbearer.state.OPTIMIZER = optimizer¶ The optimizer to use when model fitting
- torchbearer.state.PREDICTION = y_pred¶ The current batch of predictions
- torchbearer.state.SAMPLER = sampler¶ The sampler which loads data from the generator onto the correct device
- torchbearer.state.SELF = self¶ A self reference to the Trial object for persistence etc.
- torchbearer.state.STEPS = steps¶ The current number of steps per epoch
- torchbearer.state.STOP_TRAINING = stop_training¶ A flag that can be set to true to stop the current fit call
- class torchbearer.state.State[source]¶ State dictionary that behaves like a python dict but accepts StateKeys
  - data¶
- class torchbearer.state.StateKey(key)[source]¶ StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.
  Parameters: key (str) – Base key
- torchbearer.state.TARGET = y_true¶ The current batch of ground truth data
- torchbearer.state.TEST_DATA = test_data¶ The flag representing test data
- torchbearer.state.TEST_GENERATOR = test_generator¶ The test data generator in the Trial object
- torchbearer.state.TEST_STEPS = test_steps¶ The number of test steps to take
- torchbearer.state.TIMINGS = timings¶ The timings keys used by the timer callback
- torchbearer.state.TRAIN_DATA = train_data¶ The flag representing train data
- torchbearer.state.TRAIN_GENERATOR = train_generator¶ The train data generator in the Trial object
- torchbearer.state.TRAIN_STEPS = train_steps¶ The number of train steps to take
- torchbearer.state.VALIDATION_DATA = validation_data¶ The flag representing validation data
- torchbearer.state.VALIDATION_GENERATOR = validation_generator¶ The validation data generator in the Trial object
- torchbearer.state.VALIDATION_STEPS = validation_steps¶ The number of validation steps to take
- torchbearer.state.VERSION = torchbearer_version¶ The torchbearer version
- torchbearer.state.X = x¶ The current batch of inputs
- torchbearer.state.Y_PRED = y_pred¶ The current batch of predictions
- torchbearer.state.Y_TRUE = y_true¶ The current batch of ground truth data
Utilities¶
- class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)[source]¶ Generates training and validation split indices for a dataset of the given length, from which training and validation datasets can be built.
  Parameters: - dataset_len (int) – The length of the dataset to split
  - split_fraction (float) – The fraction of the dataset to use for validation
  - shuffle_seed – Optional seed for shuffling the indices before splitting
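Example: a sketch assuming the splitter's get_train_dataset / get_val_dataset helpers, where dataset is any torch Dataset:
splitter = DatasetValidationSplitter(len(dataset), 0.1)  # hold out 10% for validation
trainset = splitter.get_train_dataset(dataset)
valset = splitter.get_val_dataset(dataset)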
- torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)[source]¶ Generate validation and training datasets from whole-dataset tensors.
  Parameters: - x (torch.Tensor) – Data tensor for the dataset
  - y (torch.Tensor) – Label tensor for the dataset
  - validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting the x and y tensors
  - validation_split (float) – Fraction of the dataset to be used for validation
  - shuffle (bool) – If True, randomize the tensor order before splitting; else do not randomize
  Returns: Training and validation datasets
- torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)[source]¶ Generate training and validation tensors from whole-dataset data and label tensors.
  Parameters: - x (torch.Tensor) – Data tensor for the whole dataset
  - y (torch.Tensor) – Label tensor for the whole dataset
  - split (float) – Fraction of the dataset to be used for validation
  - shuffle (bool) – If True, randomize the tensor order before splitting; else do not randomize
  Returns: Training and validation tensors (training data, training labels, validation data, validation labels)
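Example: a minimal sketch holding out 10% of the data for validation (the tensors are illustrative):
import torch
from torchbearer.cv_utils import train_valid_splitter

x, y = torch.rand(100, 10), torch.rand(100, 1)
x_train, y_train, x_val, y_val = train_valid_splitter(x, y, split=0.1)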