torchbearer¶
Trial¶
-
class
torchbearer.
Trial
(model, optimizer=None, criterion=None, metrics=[], callbacks=[], verbose=2)[source]¶ The trial class contains all of the required hyper-parameters for model running in torchbearer and presents an API for model fitting, evaluating and predicting.
Example:
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
# Example trial that attempts to minimise the output of a linear layer.
# Makes use of a callback to input the random data at each batch and a loss that is the absolute value of the
# linear layer output. Runs for 10 iterations and a single epoch.
>>> model = torch.nn.Linear(2, 1)
>>> optimiser = torch.optim.Adam(model.parameters(), lr=3e-4)
>>> @torchbearer.callbacks.on_sample
... def initial_data(state):
...     state[torchbearer.X] = torch.rand(1, 2)*10
>>> def minimise_output_loss(y_pred, y_true):
...     return torch.abs(y_pred)
>>> trial = Trial(model, optimiser, minimise_output_loss, ['loss'], [initial_data]).for_steps(10).run(1)
@article{2018torchbearer, title={Torchbearer: A Model Fitting Library for PyTorch}, author={Harris, Ethan and Painter, Matthew and Hare, Jonathon}, journal={arXiv preprint arXiv:1809.03363}, year={2018} }
Parameters: - model (torch.nn.Module) – The base pytorch model
- optimizer (torch.optim.Optimizer) – The optimizer used for pytorch model weight updates
- criterion (func / None) – The final loss criterion that provides a loss value to the optimizer
- metrics (list) – The list of torchbearer.Metric instances to process during fitting
- callbacks (list) – The list of torchbearer.Callback instances to call during fitting
- verbose (int) – Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
-
for_train_steps
(steps)[source]¶ Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size then the loader will be refreshed as if it were a new epoch. If steps is -1 then the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 100 training iterations, in this case optimising nothing
>>> from torchbearer import Trial
>>> trial = Trial(None).for_train_steps(100)
Parameters: steps (int) – The number of training steps per epoch to run. Returns: self Return type: Trial
-
with_train_generator
(generator, steps=None)[source]¶ Use this trial with the given train generator. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 100 training iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=True))
>>> trial = Trial(None).with_train_generator(dataloader).for_steps(100).run(1)
Parameters: - generator – The train data generator to use during calls to run()
- steps (int) – The number of steps per epoch to take when using this generator.
Returns: self
Return type: Trial
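The "returns self" convention used by the with_* and for_* methods is what allows calls to be chained. The following is a minimal sketch of that pattern only; MiniTrial is a hypothetical toy class, not torchbearer.Trial, and its default of one pass over the generator when steps is None is an assumption made for illustration.

```python
class MiniTrial:
    """Toy stand-in for a fluent configuration object (not torchbearer.Trial)."""

    def __init__(self):
        self.train_generator = None
        self.train_steps = None

    def with_train_generator(self, generator, steps=None):
        self.train_generator = generator
        # Assumed default: one pass over the generator when no step count is given.
        self.train_steps = len(generator) if steps is None else steps
        return self  # returning self is what makes chaining possible

    def for_train_steps(self, steps):
        self.train_steps = steps
        return self


batches = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]  # stand-in for a DataLoader
trial = MiniTrial().with_train_generator(batches).for_train_steps(2)
print(trial.train_steps)
```

Because every configuration method hands back the same object, the order of chained calls only matters where a later call overrides an earlier setting, as for_train_steps does here.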
-
with_train_data
(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]¶ Use this trial with the given train data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 training iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> trial = Trial(None).with_train_data(data, targets).for_steps(10).run(1)
Parameters: - x (torch.Tensor) – The train x data to use during calls to run()
- y (torch.Tensor) – The train labels to use during calls to run()
- batch_size (int) – The size of each batch to sample from the data
- shuffle (bool) – If True, then data will be shuffled each epoch
- num_workers (int) – Number of worker threads to use in the data loader
- steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
-
for_val_steps
(steps)[source]¶ Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size then the loader will be refreshed as if it were a new epoch. If steps is -1 then the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 10 validation iterations on no data
>>> from torchbearer import Trial
>>> trial = Trial(None).for_val_steps(10).run(1)
Parameters: steps (int) – The number of validation steps per epoch to run Returns: self Return type: Trial
-
with_val_generator
(generator, steps=None)[source]¶ Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 100 validation iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_val_generator(dataloader).for_steps(100).run(1)
Parameters: - generator – The validation data generator to use during calls to run() and evaluate()
- steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial
-
with_val_data
(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]¶ Use this trial with the given validation data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 validation iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> trial = Trial(None).with_val_data(data, targets).for_steps(10).run(1)
Parameters: - x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
- y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
- batch_size (int) – The size of each batch to sample from the data
- shuffle (bool) – If True, then data will be shuffled each epoch
- num_workers (int) – Number of worker threads to use in the data loader
- steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
-
for_test_steps
(steps)[source]¶ Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size then the loader will be refreshed as if it were a new epoch. If steps is -1 then the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> trial = Trial(None).with_test_data(data).for_test_steps(10).run(1)
Parameters: steps (int) – The number of test steps per epoch to run (when using predict()) Returns: self Return type: Trial
-
with_test_generator
(generator, steps=None)[source]¶ Use this trial with the given test generator. Returns self so that methods can be chained for convenience.
Example:
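# Simple trial that runs for 10 test iterations on the MNIST test set
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_test_generator(dataloader).for_test_steps(10).run(1)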
Parameters: - generator – The test data generator to use during calls to predict()
- steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial
-
with_test_data
(x, batch_size=1, num_workers=1, steps=None)[source]¶ Use this trial with the given test data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> trial = Trial(None).with_test_data(data).for_test_steps(10).run(1)
Parameters: - x (torch.Tensor) – The test x data to use during calls to predict()
- batch_size (int) – The size of each batch to sample from the data
- num_workers (int) – Number of worker threads to use in the data loader
- steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
-
for_steps
(train_steps=None, val_steps=None, test_steps=None)[source]¶ Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size then the loader will be refreshed as if it were a new epoch. If steps is -1 then the loader will be refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 10 training, validation and test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> train_data, train_targets = torch.rand(10, 1), torch.rand(10, 1)
>>> val_data, val_targets = torch.rand(10, 1), torch.rand(10, 1)
>>> test_data = torch.rand(10, 1)
>>> trial = Trial(None).with_train_data(train_data, train_targets).with_val_data(val_data, val_targets).with_test_data(test_data)
>>> trial.for_steps(10, 10, 10).run(1)
Parameters: - train_steps (int) – The number of training steps per epoch to run
- val_steps (int) – The number of validation steps per epoch to run
- test_steps (int) – The number of test steps per epoch to run (when using predict())
Returns: self
Return type: Trial
-
with_generators
(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)[source]¶ Use this trial with the given generators. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 100 steps from a training and validation data generator
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> valloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_generators(trainloader, valloader, train_steps=100, val_steps=100).run(1)
Parameters: - train_generator – The training data generator to use during calls to run()
- val_generator – The validation data generator to use during calls to run() and evaluate()
- test_generator – The testing data generator to use during calls to predict()
- train_steps (int) – The number of steps per epoch to take when using the training generator
- val_steps (int) – The number of steps per epoch to take when using the validation generator
- test_steps (int) – The number of steps per epoch to take when using the testing generator
Returns: self
Return type: Trial
-
with_data
(x_train=None, y_train=None, x_val=None, y_val=None, x_test=None, batch_size=1, num_workers=1, train_steps=None, val_steps=None, test_steps=None, shuffle=True)[source]¶ Use this trial with the given data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> test_data = torch.rand(10, 1)
>>> trial = Trial(None).with_data(x_train=data, y_train=targets, x_test=test_data)
>>> trial.for_test_steps(10).run(1)
Parameters: - x_train (torch.Tensor) – The training data to use
- y_train (torch.Tensor) – The training targets to use
- x_val (torch.Tensor) – The validation data to use
- y_val (torch.Tensor) – The validation targets to use
- x_test (torch.Tensor) – The test data to use
- batch_size (int) – Batch size to use in mini-batching
- num_workers (int) – Number of workers to use for data loading and batching
- train_steps (int) – Number of steps for each training pass
- val_steps (int) – Number of steps for each validation pass
- test_steps (int) – Number of steps for each test pass
- shuffle (bool) – If True, shuffle training and validation data.
Returns: self
Return type: Trial
-
for_inf_train_steps
()[source]¶ Use this trial with an infinite number of training steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs training data until stopped
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> trial = Trial(None).with_train_generator(trainloader).for_inf_train_steps()
>>> trial.run(1)
Returns: self Return type: Trial
-
for_inf_val_steps
()[source]¶ Use this trial with an infinite number of validation steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs validation data until stopped
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> valloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_val_generator(valloader).for_inf_val_steps()
>>> trial.run(1)
Returns: self Return type: Trial
-
for_inf_test_steps
()[source]¶ Use this trial with an infinite number of test steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs test data until stopped
>>> import torch
>>> from torchbearer import Trial
>>> test_data = torch.rand(1000, 10)
>>> trial = Trial(None).with_test_data(test_data).for_inf_test_steps()
>>> trial.run(1)
Returns: self Return type: Trial
-
for_inf_steps
(train=True, val=True, test=True)[source]¶ Use this trial with infinite steps. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that uses infinite training and test steps, keeping validation finite
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> valloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_train_generator(trainloader).with_val_generator(valloader)
>>> trial.for_inf_steps(True, False, True).run(1)
Parameters: - train (bool) – Use an infinite number of training steps
- val (bool) – Use an infinite number of validation steps
- test (bool) – Use an infinite number of test steps
Returns: self
Return type: Trial
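With infinite steps, the only way out of the loop is a stop flag set from elsewhere, typically a callback. The sketch below shows that control flow in plain Python. The string key, run_until_stopped and stop_after_seven are all hypothetical stand-ins for illustration and are not torchbearer functions.

```python
import itertools

STOP_TRAINING = "stop_training"  # plain string standing in for the state key


def run_until_stopped(loader, on_step, state):
    """Cycle over `loader` indefinitely until a callback sets the stop flag."""
    state[STOP_TRAINING] = False
    steps_taken = 0
    for batch in itertools.cycle(loader):
        on_step(state, batch)
        steps_taken += 1
        if state[STOP_TRAINING]:
            break
    return steps_taken


def stop_after_seven(state, batch):
    # Hypothetical callback: count batches and request a stop on the seventh.
    state["seen"] = state.get("seen", 0) + 1
    if state["seen"] >= 7:
        state[STOP_TRAINING] = True


state = {}
total = run_until_stopped([0, 1, 2], stop_after_seven, state)
print(total)
```

Note that the three-item loader is traversed more than twice before the flag is raised, which mirrors the "refreshed until stopped" behaviour described for steps of -1.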
-
with_inf_train_loader
()[source]¶ Use this trial with a training iterator that refreshes when it finishes instead of each epoch. This allows for setting training steps less than the size of the generator and model will still be trained on all training samples if enough “epochs” are run.
Example:
# Simple trial that runs 10 epochs of 100 iterations of a training generator without reshuffling until all data has been seen
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> trial = Trial(None).with_train_generator(trainloader, steps=100).with_inf_train_loader()
>>> trial.run(10)
Returns: self Return type: Trial
-
with_loader
(batch_loader)[source]¶ Use this trial with custom batch loader. Usually calls next on state[torchbearer.ITERATOR] and populates state[torchbearer.X] and state[torchbearer.Y_TRUE]
Example:
# Simple trial that runs with a custom loader function that populates X and Y_TRUE in state with random data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> def custom_loader(state):
...     state[torchbearer.X], state[torchbearer.Y_TRUE] = torch.rand(5, 5), torch.rand(5, 5)
>>> trial = Trial(None).with_loader(custom_loader)
>>> trial.run(10)
Parameters: batch_loader (function) – Function of state that extracts data from data loader (stored under torchbearer.ITERATOR), stores it in state and sends it to the correct device Returns: self Return type: Trial
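A batch loader, as described above, is just a function of state that pulls the next batch from the iterator and writes the inputs and targets back into state. The following sketch shows that contract with a plain dict. The string keys stand in for torchbearer.ITERATOR, torchbearer.X and torchbearer.Y_TRUE so that it runs without torchbearer, and scaling_loader is a hypothetical example function; device placement is omitted.

```python
# Plain strings stand in for torchbearer.ITERATOR, torchbearer.X and
# torchbearer.Y_TRUE so that the sketch runs without torchbearer installed.
ITERATOR, X, Y_TRUE = "iterator", "x", "y_true"


def scaling_loader(state):
    """Hypothetical batch loader: fetch the next batch and halve the inputs."""
    inputs, targets = next(state[ITERATOR])
    state[X] = [value / 2 for value in inputs]
    state[Y_TRUE] = targets


state = {ITERATOR: iter([([2.0, 4.0], [1, 0]), ([6.0, 8.0], [0, 1])])}
scaling_loader(state)
first = (state[X], state[Y_TRUE])
scaling_loader(state)
second = (state[X], state[Y_TRUE])
print(first, second)
```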
-
with_closure
(closure)[source]¶ Use this trial with custom closure
Example:
# Simple trial that runs with a custom closure
>>> import torchbearer
>>> from torchbearer import Trial
>>> def custom_closure(state):
...     print(state[torchbearer.BATCH])
>>> trial = Trial(None).with_closure(custom_closure).for_steps(3)
>>> _ = trial.run(1)
0
1
2
Parameters: closure (function) – Function of state that defines the custom closure Returns: self Return type: Trial
-
run
(epochs=1, verbose=-1)[source]¶ Run this trial for the given number of epochs, starting from the last trained epoch.
Example:
# Simple trial that runs for 100 steps over a single epoch
>>> from torchbearer import Trial
>>> trial = Trial(None).for_steps(100)
>>> _ = trial.run(1)
Parameters: - epochs (int, optional) – The number of epochs to run for
- verbose (int, optional) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
State Requirements: - torchbearer.state.MODEL: Model should be callable and not None, set on Trial init
Returns: The model history (list of tuple of steps summary and epoch metric dicts)
Return type: list
-
evaluate
(verbose=-1, data_key=None)[source]¶ Evaluate this trial on the validation data.
Example:
# Simple trial to evaluate on both validation and test data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> test_data = torch.rand(5, 5)
>>> val_data, val_targets = torch.rand(5, 5), torch.rand(5, 5)
>>> t = Trial(None).with_val_data(val_data, val_targets).with_test_data(test_data)
>>> val_metrics = t.evaluate(data_key=torchbearer.VALIDATION_DATA)
>>> test_metrics = t.evaluate(data_key=torchbearer.TEST_DATA)
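Parameters: - verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
- data_key (StateKey) – Optional key for the data to evaluate on, e.g. torchbearer.VALIDATION_DATA
Returns: The final metric values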
Return type: dict
-
predict
(verbose=-1, data_key=None)[source]¶ Determine predictions for this trial on the test data.
Example:
# Simple trial to predict on some test data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> test_data = torch.rand(5, 5)
>>> t = Trial(None).with_test_data(test_data)
>>> test_predictions = t.predict(data_key=torchbearer.TEST_DATA)
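Parameters: - verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
- data_key (StateKey) – Optional key for the data to predict on, e.g. torchbearer.TEST_DATA
Returns: Model outputs as a list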
Return type: list
-
replay
(callbacks=None, verbose=2, one_batch=False)[source]¶ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.
Example:
>>> import torch
>>> from torchbearer import Trial
>>> state = torch.load('some_state.pt')
>>> t = Trial(None).load_state_dict(state)
>>> t.replay()
Parameters: - callbacks (list) – List of callbacks to be run during the replay
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
- one_batch (bool) – If True, only one batch per epoch is replayed. If False, all batches are replayed
Returns: self
Return type: Trial
-
train
()[source]¶ Set model and metrics to training mode.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).train()
Returns: self Return type: Trial
-
eval
()[source]¶ Set model and metrics to evaluation mode.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).eval()
Returns: self Return type: Trial
-
to
(*args, **kwargs)[source]¶ Moves and/or casts the parameters and buffers.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).to('cuda:1')
Parameters: - args – See: torch.nn.Module.to
- kwargs – See: torch.nn.Module.to
Returns: self
Return type: Trial
-
cuda
(device=None)[source]¶ Moves all model parameters and buffers to the GPU.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).cuda()
Parameters: device (int) – if specified, all parameters will be copied to that device Returns: self Return type: Trial
-
cpu
()[source]¶ Moves all model parameters and buffers to the CPU.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).cpu()
Returns: self Return type: Trial
-
state_dict
(**kwargs)[source]¶ Get a dict containing the model and optimizer states, as well as the model history.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None)
>>> state = t.state_dict()  # State dict that can now be saved with torch.save
Parameters: kwargs – See: torch.nn.Module.state_dict Returns: A dict containing parameters and persistent buffers. Return type: dict
-
load_state_dict
(state_dict, resume=True, **kwargs)[source]¶ Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, just load the model state when resume=False.
Example:
>>> import torch
>>> from torchbearer import Trial
>>> t = Trial(None)
>>> state = torch.load('some_state.pt')
>>> t.load_state_dict(state)
Parameters: - state_dict (dict) – The state dict to reload
- resume (bool) – If True, resume from the given state. Else, just load in the model weights.
- kwargs – See: torch.nn.Module.load_state_dict
Returns: self
Return type: Trial
Batch Loaders¶
-
torchbearer.trial.
load_batch_infinite
(loader)[source]¶ Wraps a batch loader and refreshes the iterator once it has been completed.
Parameters: loader – batch loader to wrap
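One way to picture this wrapper: it calls the wrapped batch loader and, if the iterator is exhausted, rebuilds the iterator from the generator and tries again. The sketch below is a plain-Python approximation under that assumption. The string keys, simple_loader and make_infinite are hypothetical names for illustration and do not reproduce torchbearer's implementation.

```python
ITERATOR, GENERATOR, X = "iterator", "generator", "x"  # stand-in state keys


def simple_loader(state):
    """Minimal batch loader: put the next item of the iterator under X."""
    state[X] = next(state[ITERATOR])


def make_infinite(loader):
    """Wrap `loader` so an exhausted iterator is rebuilt from the generator."""
    def infinite_loader(state):
        try:
            loader(state)
        except StopIteration:
            # Iterator finished: refresh it and load the first batch again.
            state[ITERATOR] = iter(state[GENERATOR])
            loader(state)
    return infinite_loader


data = [10, 20]
state = {GENERATOR: data, ITERATOR: iter(data)}
load = make_infinite(simple_loader)
seen = []
for _ in range(5):
    load(state)
    seen.append(state[X])
print(seen)
```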
-
torchbearer.trial.
load_batch_none
(state)[source]¶ Load a (None, None) tuple mini-batch into state
Parameters: state (dict) – The current state dict of the Trial.
Misc¶
-
torchbearer.trial.
deep_to
(batch, device, dtype)[source]¶ Static method to call to() on tensors, tuples or dicts. All items will have deep_to() called.
Example:
>>> import torch
>>> from torchbearer import deep_to
>>> example_dict = {'a': torch.ones(5)*2.1, 'b': torch.ones(1)*5.9}
>>> deep_to(example_dict, device='cpu', dtype=torch.int)
{'a': tensor([2, 2, 2, 2, 2], dtype=torch.int32), 'b': tensor([5], dtype=torch.int32)}
Parameters: - batch (tuple / list / torch.Tensor / dict) – The mini-batch which requires a to() call
- device (torch.device) – The desired device of the batch
- dtype (torch.dtype) – The desired datatype of the batch
Returns: The moved or casted batch
Return type: tuple / list / torch.Tensor / dict
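The recursion behind a helper like this can be sketched without tensors: walk the nested container, rebuild each level with the same container type, and apply an operation to every leaf. In the sketch below, deep_apply is a hypothetical name and a generic function fn stands in for the to() call; it is an illustration of the idea, not torchbearer's code.

```python
def deep_apply(batch, fn):
    """Apply `fn` to every leaf of a nested tuple / list / dict structure."""
    if isinstance(batch, dict):
        return {key: deep_apply(value, fn) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        # Rebuild the same container type around the converted items.
        return type(batch)(deep_apply(item, fn) for item in batch)
    return fn(batch)  # leaf value: apply the operation directly


nested = {"a": [2.1, 2.9], "b": (5.9, {"c": 7.5})}
casted = deep_apply(nested, int)
print(casted)
```

Casting with int truncates towards zero, which matches the truncation visible in the tensor example above.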
State¶
The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced
during model fitting. This module defines classes for interacting with state and all of the built in state keys used
throughout torchbearer. The state_key()
function can be used to create custom state keys for use in callbacks or
metrics.
Example:
>>> from torchbearer import state_key
>>> MY_KEY = state_key('my_test_key')
State¶
-
class
torchbearer.state.
State
[source]¶ State dictionary that behaves like a python dict but accepts StateKeys
-
data
¶
-
-
class
torchbearer.state.
StateKey
(key)[source]¶ StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.
Parameters: key (str) – Base key
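The text does not say how uniqueness is achieved, so the following is only one plausible scheme, written as a self-contained toy: keep a registry of names already in use and append a numeric suffix on collision. SketchKey and its suffix format are assumptions for illustration and should not be read as torchbearer's behaviour.

```python
class SketchKey:
    """Toy unique key; a suffix is appended when a name is already taken."""

    _taken = set()  # registry of names handed out so far

    def __init__(self, key):
        name, count = key, 0
        while name in SketchKey._taken:
            count += 1
            name = "{}_{}".format(key, count)
        SketchKey._taken.add(name)
        self.key = name

    def __hash__(self):
        return hash(self.key)

    def __eq__(self, other):
        return isinstance(other, SketchKey) and self.key == other.key

    def __repr__(self):
        return self.key


first = SketchKey("my_test_key")
second = SketchKey("my_test_key")
state = {first: 1, second: 2}  # both keys coexist in one dict
print(first, second)
```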
Key List¶
-
torchbearer.state.
BACKWARD_ARGS
= backward_args The optional arguments which should be passed to the backward call
-
torchbearer.state.
BATCH
= t The current batch number
-
torchbearer.state.
CALLBACK_LIST
= callback_list The
CallbackList
object which is called by the Trial
-
torchbearer.state.
CRITERION
= criterion The criterion to use when model fitting
-
torchbearer.state.
DATA
= data The string name of the current data
-
torchbearer.state.
DATA_TYPE
= dtype The data type of tensors in use by the model, match this to avoid type issues
-
torchbearer.state.
DEVICE
= device The device currently in use by the
Trial
and PyTorch model
-
torchbearer.state.
EPOCH
= epoch The current epoch number
-
torchbearer.state.
FINAL_PREDICTIONS
= final_predictions The key which maps to the predictions over the dataset when calling predict
-
torchbearer.state.
GENERATOR
= generator The current data generator (DataLoader)
-
torchbearer.state.
HISTORY
= history The history list of the Trial instance
-
torchbearer.state.
INF_TRAIN_LOADING
= inf_train_loading Flag for refreshing of training iterator when finished instead of each epoch
-
torchbearer.state.
INPUT
= x The current batch of inputs
-
torchbearer.state.
ITERATOR
= iterator The current iterator
-
torchbearer.state.
LOADER
= loader The batch loader which handles formatting data from each batch
-
torchbearer.state.
LOSS
= loss The current value for the loss
-
torchbearer.state.
MAX_EPOCHS
= max_epochs The total number of epochs to run for
-
torchbearer.state.
METRICS
= metrics The metric dict from the current batch of data
-
torchbearer.state.
METRIC_LIST
= metric_list The list of metrics in use by the
Trial
-
torchbearer.state.
MIXUP_LAMBDA
= mixup_lambda The lambda coefficient of the linear combination of inputs
-
torchbearer.state.
MIXUP_PERMUTATION
= mixup_permutation The permutation of input indices for input mixup
-
torchbearer.state.
MODEL
= model The PyTorch module / model that will be trained
-
torchbearer.state.
OPTIMIZER
= optimizer The optimizer to use when model fitting
-
torchbearer.state.
PREDICTION
= y_pred The current batch of predictions
-
torchbearer.state.
SAMPLER
= sampler The sampler which loads data from the generator onto the correct device
-
torchbearer.state.
SELF
= self A self reference to the Trial object for persistence etc.
-
torchbearer.state.
STEPS
= steps The current number of steps per epoch
-
torchbearer.state.
STOP_TRAINING
= stop_training A flag that can be set to true to stop the current fit call
-
torchbearer.state.
TARGET
= y_true The current batch of ground truth data
-
torchbearer.state.
TEST_DATA
= test_data The flag representing test data
-
torchbearer.state.
TEST_GENERATOR
= test_generator The test data generator in the Trial object
-
torchbearer.state.
TEST_STEPS
= test_steps The number of test steps to take
-
torchbearer.state.
TIMINGS
= timings The timings keys used by the timer callback
-
torchbearer.state.
TRAIN_DATA
= train_data The flag representing train data
-
torchbearer.state.
TRAIN_GENERATOR
= train_generator The train data generator in the Trial object
-
torchbearer.state.
TRAIN_STEPS
= train_steps The number of train steps to take
-
torchbearer.state.
VALIDATION_DATA
= validation_data The flag representing validation data
-
torchbearer.state.
VALIDATION_GENERATOR
= validation_generator The validation data generator in the Trial object
-
torchbearer.state.
VALIDATION_STEPS
= validation_steps The number of validation steps to take
-
torchbearer.state.
VERSION
= torchbearer_version The torchbearer version
-
torchbearer.state.
X
= x The current batch of inputs
-
torchbearer.state.
Y_PRED
= y_pred The current batch of predictions
-
torchbearer.state.
Y_TRUE
= y_true The current batch of ground truth data
Utilities¶
-
class
torchbearer.cv_utils.
DatasetValidationSplitter
(dataset_len, split_fraction, shuffle_seed=None)[source]¶ Generates training and validation split indices for a given dataset length and creates training and validation datasets using these splits
Parameters: - dataset_len – The length of the dataset to be split into training and validation
- split_fraction – The fraction of the whole dataset to be used for validation
- shuffle_seed – Optional random seed for the shuffling process
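The index split itself can be sketched in a few lines: shuffle the indices with an optional seed, then cut off the validation fraction. The function name split_indices and the choice to take the validation indices from the front of the shuffled list are assumptions made for this illustration.

```python
import random


def split_indices(dataset_len, split_fraction, shuffle_seed=None):
    """Return (train_ids, valid_ids) for a dataset of length `dataset_len`."""
    ids = list(range(dataset_len))
    random.Random(shuffle_seed).shuffle(ids)  # seeded for reproducibility
    num_valid = int(dataset_len * split_fraction)
    return ids[num_valid:], ids[:num_valid]


train_ids, valid_ids = split_indices(10, 0.2, shuffle_seed=0)
print(len(train_ids), len(valid_ids))
```

Passing the same seed twice yields the same split, which is what makes a validation set comparable across runs.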
-
class
torchbearer.cv_utils.
SubsetDataset
(dataset, ids)[source]¶ Dataset that consists of a subset of a previous dataset
Parameters: - dataset (torch.utils.data.Dataset) – Complete dataset
- ids (list) – List of subset IDs
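A subset dataset of this kind only needs to translate positions within the subset into indices of the full dataset. The toy class below shows the idea with a plain list; SubsetView is a hypothetical name and the class is a sketch rather than the library's SubsetDataset.

```python
class SubsetView:
    """Toy dataset exposing only the items of `dataset` listed in `ids`."""

    def __init__(self, dataset, ids):
        self.dataset = dataset
        self.ids = ids

    def __getitem__(self, index):
        # Translate the subset position into an index of the full dataset.
        return self.dataset[self.ids[index]]

    def __len__(self):
        return len(self.ids)


full = ["a", "b", "c", "d", "e"]
subset = SubsetView(full, ids=[4, 0, 2])
items = [subset[i] for i in range(len(subset))]
print(items)
```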
-
torchbearer.cv_utils.
get_train_valid_sets
(x, y, validation_data, validation_split, shuffle=True)[source]¶ Generate validation and training datasets from whole dataset tensors
Parameters: - x (torch.Tensor) – Data tensor for dataset
- y (torch.Tensor) – Label tensor for dataset
- validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
- validation_split (float) – Fraction of dataset to be used for validation
- shuffle (bool) – If True randomize tensor order before splitting else do not randomize
Returns: Training and validation datasets
-
torchbearer.cv_utils.
train_valid_splitter
(x, y, split, shuffle=True)[source]¶ Generate training and validation tensors from whole dataset data and label tensors
Parameters: - x (torch.Tensor) – Data tensor for whole dataset
- y (torch.Tensor) – Label tensor for whole dataset
- split (float) – Fraction of dataset to be used for validation
- shuffle (bool) – If True randomize tensor order before splitting else do not randomize
Returns: Training and validation tensors (training data, training labels, validation data, validation labels)
-
torchbearer.bases.
base_closure
(x, model, y_pred, y_true, crit, loss, opt)[source]¶ Function to create a standard pytorch closure using objects taken from state under the given keys.
Parameters: - x – State key under which the input data is stored
- model – State key under which the pytorch model is stored
- y_pred – State key under which the predictions will be stored
- y_true – State key under which the targets are stored
- crit – State key under which the criterion function is stored (function of state or (y_pred, y_true))
- loss – State key under which the loss will be stored
- opt – State key under which the optimiser is stored
Returns: Standard closure function
Return type: function
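The shape of such a closure factory can be sketched with stub objects in place of a real model, loss and optimiser. Everything below (make_closure, FakeLoss, FakeOptimizer and the string keys) is hypothetical and simplified. In particular, it assumes the criterion takes (y_pred, y_true) and that the optimiser step happens outside the closure.

```python
class FakeLoss:
    """Stub loss value that records whether backward() was called."""

    def __init__(self, value):
        self.value = value
        self.backward_called = False

    def backward(self):
        self.backward_called = True


class FakeOptimizer:
    """Stub optimiser that counts zero_grad() calls."""

    def __init__(self):
        self.zeroed = 0

    def zero_grad(self):
        self.zeroed += 1


def make_closure(x, model, y_pred, y_true, crit, loss, opt):
    """Build a closure that reads and writes state under the given keys."""
    def closure(state):
        state[opt].zero_grad()                   # clear old gradients
        state[y_pred] = state[model](state[x])   # forward pass
        state[loss] = state[crit](state[y_pred], state[y_true])  # loss value
        state[loss].backward()                   # backward pass
    return closure


state = {
    "x": 3.0,
    "y_true": 5.0,
    "model": lambda value: 2 * value,
    "criterion": lambda pred, true: FakeLoss(abs(pred - true)),
    "optimizer": FakeOptimizer(),
}
closure = make_closure("x", "model", "y_pred", "y_true", "criterion", "loss", "optimizer")
closure(state)
print(state["y_pred"], state["loss"].value)
```

Passing keys rather than objects means the same closure can be reused after any of the stored objects in state has been swapped out.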