torchbearer
Trial
- class torchbearer.Trial(model, optimizer=None, criterion=None, metrics=[], callbacks=[], verbose=2)[source]
The trial class contains all of the required hyper-parameters for model running in torchbearer and presents an API for model fitting, evaluating and predicting.
Example:
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial

# Example trial that attempts to minimise the output of a linear layer.
# Makes use of a callback to input the random data at each batch and a loss that is the absolute value of the
# linear layer output. Runs for 10 iterations and a single epoch.
>>> model = torch.nn.Linear(2,1)
>>> optimiser = torch.optim.Adam(model.parameters(), lr=3e-4)

>>> @torchbearer.callbacks.on_sample
... def initial_data(state):
...     state[torchbearer.X] = torch.rand(1, 2)*10

>>> def minimise_output_loss(y_pred, y_true):
...     return torch.abs(y_pred)

>>> trial = Trial(model, optimiser, minimise_output_loss, ['loss'], [initial_data]).for_steps(10).run(1)
@article{2018torchbearer, title={Torchbearer: A Model Fitting Library for PyTorch}, author={Harris, Ethan and Painter, Matthew and Hare, Jonathon}, journal={arXiv preprint arXiv:1809.03363}, year={2018} }
- Parameters:
model (torch.nn.Module) – The base pytorch model
optimizer (torch.optim.Optimizer) – The optimizer used for pytorch model weight updates
criterion (func / None) – The final loss criterion that provides a loss value to the optimizer
metrics (list) – The list of torchbearer.Metric instances to process during fitting
callbacks (list) – The list of torchbearer.Callback instances to call during fitting
verbose (int) – Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
- for_train_steps(steps)[source]
Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader is refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 100 training iterations, in this case optimising nothing
>>> from torchbearer import Trial
>>> trial = Trial(None).for_train_steps(100)
- Parameters:
steps (int) – The number of training steps per epoch to run.
- Returns:
self
- Return type:
Trial
- with_train_generator(generator, steps=None)[source]
Use this trial with the given train generator. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 100 training iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=True))
>>> trial = Trial(None).with_train_generator(dataloader).for_steps(100).run(1)
- with_train_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]
Use this trial with the given train data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 training iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> trial = Trial(None).with_train_data(data, targets).for_steps(10).run(1)
- Parameters:
x (torch.Tensor) – The train x data to use during calls to run()
y (torch.Tensor) – The train labels to use during calls to run()
batch_size (int) – The size of each batch to sample from the data
shuffle (bool) – If True, then data will be shuffled each epoch
num_workers (int) – Number of worker subprocesses to use in the data loader
steps (int) – The number of steps per epoch to take when using this data
- Returns:
self
- Return type:
Trial
- for_val_steps(steps)[source]
Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader is refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 10 validation iterations on no data
>>> from torchbearer import Trial
>>> trial = Trial(None).for_val_steps(10).run(1)
- Parameters:
steps (int) – The number of validation steps per epoch to run
- Returns:
self
- Return type:
Trial
- with_val_generator(generator, steps=None)[source]
Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 100 validation iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_val_generator(dataloader).for_val_steps(100).run(1)
- Parameters:
generator – The validation data generator to use during calls to run() and evaluate()
steps (int) – The number of steps per epoch to take when using this generator
- Returns:
self
- Return type:
Trial
- with_val_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)[source]
Use this trial with the given validation data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 validation iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> trial = Trial(None).with_val_data(data, targets).for_val_steps(10).run(1)
- Parameters:
x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
batch_size (int) – The size of each batch to sample from the data
shuffle (bool) – If True, then data will be shuffled each epoch
num_workers (int) – Number of worker subprocesses to use in the data loader
steps (int) – The number of steps per epoch to take when using this data
- Returns:
self
- Return type:
Trial
- for_test_steps(steps)[source]
Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader is refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> trial = Trial(None).with_test_data(data).for_test_steps(10).run(1)
- with_test_generator(generator, steps=None)[source]
Use this trial with the given test generator. Returns self so that methods can be chained for convenience.
Example:
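# Simple trial that runs for 10 test iterations on the MNIST dataset
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_test_generator(dataloader).for_test_steps(10).run(1)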
- with_test_data(x, batch_size=1, num_workers=1, steps=None)[source]
Use this trial with the given test data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> trial = Trial(None).with_test_data(data).for_test_steps(10).run(1)
- Parameters:
x (torch.Tensor) – The test x data to use during calls to predict()
batch_size (int) – The size of each batch to sample from the data
num_workers (int) – Number of worker subprocesses to use in the data loader
steps (int) – The number of steps per epoch to take when using this data
- Returns:
self
- Return type:
Trial
- for_steps(train_steps=None, val_steps=None, test_steps=None)[source]
Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained for convenience. If steps is larger than the dataset size, the loader is refreshed as if a new epoch had started. If steps is -1, the loader is refreshed until stopped by the STOP_TRAINING flag or similar.
Example:
# Simple trial that runs for 10 training, validation and test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> train_data, train_targets = torch.rand(10, 1), torch.rand(10, 1)
>>> val_data, val_targets = torch.rand(10, 1), torch.rand(10, 1)
>>> test_data = torch.rand(10, 1)
>>> trial = Trial(None).with_train_data(train_data, train_targets).with_val_data(val_data, val_targets).with_test_data(test_data)
>>> trial.for_steps(10, 10, 10).run(1)
- with_generators(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)[source]
Use this trial with the given generators. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 100 steps from a training and validation data generator
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> valloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_generators(trainloader, valloader, train_steps=100, val_steps=100).run(1)
- Parameters:
train_generator – The training data generator to use during calls to run()
val_generator – The validation data generator to use during calls to run() and evaluate()
test_generator – The testing data generator to use during calls to predict()
train_steps (int) – The number of steps per epoch to take when using the training generator
val_steps (int) – The number of steps per epoch to take when using the validation generator
test_steps (int) – The number of steps per epoch to take when using the testing generator
- Returns:
self
- Return type:
Trial
- with_data(x_train=None, y_train=None, x_val=None, y_val=None, x_test=None, batch_size=1, num_workers=1, train_steps=None, val_steps=None, test_steps=None, shuffle=True)[source]
Use this trial with the given data. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs for 10 test iterations on some random data
>>> import torch
>>> from torchbearer import Trial
>>> data = torch.rand(10, 1)
>>> targets = torch.rand(10, 1)
>>> test_data = torch.rand(10, 1)
>>> trial = Trial(None).with_data(x_train=data, y_train=targets, x_test=test_data)
>>> trial.for_test_steps(10).run(1)
- Parameters:
x_train (torch.Tensor) – The training data to use
y_train (torch.Tensor) – The training targets to use
x_val (torch.Tensor) – The validation data to use
y_val (torch.Tensor) – The validation targets to use
x_test (torch.Tensor) – The test data to use
batch_size (int) – Batch size to use in mini-batching
num_workers (int) – Number of workers to use for data loading and batching
train_steps (int) – Number of steps for each training pass
val_steps (int) – Number of steps for each validation pass
test_steps (int) – Number of steps for each test pass
shuffle (bool) – If True, shuffle training and validation data.
- Returns:
self
- Return type:
Trial
- for_inf_train_steps()[source]
Use this trial with an infinite number of training steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs training data until stopped
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> trial = Trial(None).with_train_generator(trainloader).for_inf_train_steps()
>>> trial.run(1)
- Returns:
self
- Return type:
Trial
- for_inf_val_steps()[source]
Use this trial with an infinite number of validation steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs validation data until stopped
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> valloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_val_generator(valloader).for_inf_val_steps()
>>> trial.run(1)
- Returns:
self
- Return type:
Trial
- for_inf_test_steps()[source]
Use this trial with an infinite number of test steps (until stopped via STOP_TRAINING flag or similar). Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs test data until stopped
>>> import torch
>>> from torchbearer import Trial
>>> test_data = torch.rand(1000, 10)
>>> trial = Trial(None).with_test_data(test_data).for_inf_test_steps()
>>> trial.run(1)
- Returns:
self
- Return type:
Trial
- for_inf_steps(train=True, val=True, test=True)[source]
Use this trial with infinite steps. Returns self so that methods can be chained for convenience.
Example:
# Simple trial that runs training and test data until stopped
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> valloader = DataLoader(MNIST('./data/', train=False))
>>> trial = Trial(None).with_train_generator(trainloader).with_val_generator(valloader)
>>> trial.for_inf_steps(True, False, True).run(1)
- Parameters:
train (bool) – Use an infinite number of training steps
val (bool) – Use an infinite number of validation steps
test (bool) – Use an infinite number of test steps
- Returns:
self
- Return type:
Trial
- with_inf_train_loader()[source]
Use this trial with a training iterator that refreshes when it finishes instead of each epoch. This allows for setting training steps less than the size of the generator, and the model will still be trained on all training samples if enough “epochs” are run.
Example:
# Simple trial that runs 10 epochs of 100 iterations of a training generator without reshuffling until all data has been seen
>>> from torchbearer import Trial
>>> from torchvision.datasets import MNIST
>>> from torch.utils.data import DataLoader
>>> trainloader = DataLoader(MNIST('./data/', train=True))
>>> trial = Trial(None).with_train_generator(trainloader, steps=100).with_inf_train_loader()
>>> trial.run(10)
- Returns:
self
- Return type:
Trial
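The difference between refreshing the iterator every epoch and refreshing it only once it is exhausted can be illustrated without torchbearer. The following is a plain-Python sketch, not the library's implementation: `run_epochs` and its `refresh_each_epoch` flag are hypothetical names, and a list stands in for an unshuffled data generator.

```python
def run_epochs(data, epochs, steps, refresh_each_epoch):
    """Return the samples seen in each epoch when taking `steps` samples per epoch."""
    seen = []
    iterator = iter(data)
    for _ in range(epochs):
        if refresh_each_epoch:
            # Per-epoch refresh: every epoch starts from the first sample again.
            iterator = iter(data)
        epoch_samples = []
        for _ in range(steps):
            try:
                sample = next(iterator)
            except StopIteration:
                # Refresh only once the data has been exhausted.
                iterator = iter(data)
                sample = next(iterator)
            epoch_samples.append(sample)
        seen.append(epoch_samples)
    return seen


data = [0, 1, 2, 3, 4]  # stand-in for an unshuffled data loader

per_epoch = run_epochs(data, epochs=3, steps=2, refresh_each_epoch=True)
on_exhaustion = run_epochs(data, epochs=3, steps=2, refresh_each_epoch=False)

print(per_epoch)      # [[0, 1], [0, 1], [0, 1]]
print(on_exhaustion)  # [[0, 1], [2, 3], [4, 0]]
```

With a per-epoch refresh and fewer steps than samples, the later samples are never reached unless the data is shuffled. Refreshing only on exhaustion walks through the whole dataset across epochs.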
- with_loader(batch_loader)[source]
Use this trial with a custom batch loader. Usually calls next on state[torchbearer.ITERATOR] and populates state[torchbearer.X] and state[torchbearer.Y_TRUE]
Example:
# Simple trial that runs with a custom loader function that populates X and Y_TRUE in state with random data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> def custom_loader(state):
...     state[torchbearer.X], state[torchbearer.Y_TRUE] = torch.rand(5, 5), torch.rand(5, 5)
>>> trial = Trial(None).with_loader(custom_loader)
>>> trial.run(10)
- Parameters:
batch_loader (function) – Function of state that extracts data from data loader (stored under torchbearer.ITERATOR), stores it in state and sends it to the correct device
- Returns:
self
- Return type:
Trial
- with_closure(closure)[source]
Use this trial with a custom closure.
Example:
# Simple trial that runs with a custom closure
>>> import torchbearer
>>> from torchbearer import Trial
>>> def custom_closure(state):
...     print(state[torchbearer.BATCH])
>>> trial = Trial(None).with_closure(custom_closure).for_steps(3)
>>> _ = trial.run(1)
0
1
2
- Parameters:
closure (function) – Function of state that defines the custom closure
- Returns:
self
- Return type:
Trial
- run(epochs=1, verbose=-1)[source]
Run this trial for the given number of epochs, starting from the last trained epoch.
Example:
# Simple trial that runs for a single epoch of 100 steps
>>> from torchbearer import Trial
>>> trial = Trial(None).for_steps(100)
>>> _ = trial.run(1)
- Parameters:
epochs (int, optional) – The number of epochs to run for
verbose (int, optional) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
- State Requirements:
torchbearer.state.MODEL: Model should be callable and not None, set on Trial init
- Returns:
The model history (list of tuple of steps summary and epoch metric dicts)
- Return type:
list
- evaluate(verbose=-1, data_key=None)[source]
Evaluate this trial on the validation data.
Example:
# Simple trial to evaluate on both validation and test data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> test_data = torch.rand(5, 5)
>>> val_data = torch.rand(5, 5)
>>> val_targets = torch.rand(5, 5)
>>> t = Trial(None).with_val_data(val_data, val_targets).with_test_data(test_data)
>>> val_metrics = t.evaluate(data_key=torchbearer.VALIDATION_DATA)
>>> test_metrics = t.evaluate(data_key=torchbearer.TEST_DATA)
- predict(verbose=-1, data_key=None)[source]
Determine predictions for this trial on the test data.
Example:
# Simple trial to predict on some test data
>>> import torch
>>> import torchbearer
>>> from torchbearer import Trial
>>> test_data = torch.rand(5, 5)
>>> t = Trial(None).with_test_data(test_data)
>>> test_predictions = t.predict(data_key=torchbearer.TEST_DATA)
- replay(callbacks=None, verbose=2, one_batch=False)[source]
Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.
Example:
>>> import torch
>>> from torchbearer import Trial
>>> state = torch.load('some_state.pt')
>>> t = Trial(None).load_state_dict(state)
>>> t.replay()
- Parameters:
callbacks (list) – List of callbacks to be run during the replay
verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
one_batch (bool) – If True, only one batch per epoch is replayed. If False, all batches are replayed
- Returns:
self
- Return type:
Trial
- train()[source]
Set model and metrics to training mode.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).train()
- Returns:
self
- Return type:
Trial
- eval()[source]
Set model and metrics to evaluation mode.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).eval()
- Returns:
self
- Return type:
Trial
- to(*args, **kwargs)[source]
Moves and/or casts the parameters and buffers.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).to('cuda:1')
- Parameters:
args – See: torch.nn.Module.to
kwargs – See: torch.nn.Module.to
- Returns:
self
- Return type:
Trial
- cuda(device=None)[source]
Moves all model parameters and buffers to the GPU.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).cuda()
- Parameters:
device (int) – if specified, all parameters will be copied to that device
- Returns:
self
- Return type:
Trial
- cpu()[source]
Moves all model parameters and buffers to the CPU.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None).cpu()
- Returns:
self
- Return type:
Trial
- state_dict(**kwargs)[source]
Get a dict containing the model and optimizer states, as well as the model history.
Example:
>>> from torchbearer import Trial
>>> t = Trial(None)
>>> state = t.state_dict() # State dict that can now be saved with torch.save
- Parameters:
kwargs – See: torch.nn.Module.state_dict
- Returns:
A dict containing parameters and persistent buffers.
- Return type:
dict
- load_state_dict(state_dict, resume=True, **kwargs)[source]
Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, just load the model state when resume=False.
Example:
>>> import torch
>>> from torchbearer import Trial
>>> t = Trial(None)
>>> state = torch.load('some_state.pt')
>>> t.load_state_dict(state)
- Parameters:
state_dict (dict) – The state dict to reload
resume (bool) – If True, resume from the given state. Else, just load in the model weights.
kwargs – See: torch.nn.Module.load_state_dict
- Returns:
self
- Return type:
Trial
Batch Loaders
- torchbearer.trial.load_batch_infinite(loader)[source]
Wraps a batch loader and refreshes the iterator once it has been completed.
- Parameters:
loader – batch loader to wrap
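The wrapping described above can be sketched in plain Python. This is an illustration under stated assumptions, not torchbearer's source: the string keys `"iterator"` and `"generator"` stand in for torchbearer.ITERATOR and torchbearer.GENERATOR, and `load_batch_standard` and `load_batch_infinite_sketch` are hypothetical names.

```python
def load_batch_standard(state):
    """Hypothetical batch loader: take the next (x, y) pair from the iterator."""
    state["x"], state["y_true"] = next(state["iterator"])


def load_batch_infinite_sketch(loader):
    """Wrap `loader` so that an exhausted iterator is rebuilt and the load retried."""
    def call(state):
        try:
            loader(state)
        except StopIteration:
            # Rebuild the iterator from the generator, then load again.
            state["iterator"] = iter(state["generator"])
            loader(state)
    return call


generator = [(1, "a"), (2, "b")]  # stand-in for a DataLoader
state = {"generator": generator, "iterator": iter(generator)}
infinite_loader = load_batch_infinite_sketch(load_batch_standard)

inputs = []
for _ in range(5):
    infinite_loader(state)
    inputs.append(state["x"])

print(inputs)  # [1, 2, 1, 2, 1]
```

Because the wrapper has the same signature as the loader it wraps, it can be used anywhere a function of state is expected.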
- torchbearer.trial.load_batch_none(state)[source]
Load a (None, None) tuple mini-batch into state
- Parameters:
state (dict) – The current state dict of the Trial.
Misc
- torchbearer.trial.deep_to(batch, device, dtype)[source]
Static method to call to() on tensors, tuples or dicts. All items will have deep_to() called.
Example:
>>> import torch
>>> from torchbearer import deep_to
>>> example_dict = {'a': torch.ones(5)*2.1, 'b': torch.ones(1)*5.9}
>>> deep_to(example_dict, device='cpu', dtype=torch.int)
{'a': tensor([2, 2, 2, 2, 2], dtype=torch.int32), 'b': tensor([5], dtype=torch.int32)}
- Parameters:
batch (tuple / list / torch.Tensor / dict) – The mini-batch which requires a to() call
device (torch.device) – The desired device of the batch
dtype (torch.dtype) – The desired datatype of the batch
- Returns:
The moved or cast batch
- Return type:
tuple / list / torch.Tensor
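The recursive traversal that such a helper performs can be sketched without torch. In this illustrative version, a generic callable `fn` stands in for the tensor `to()` call, and `deep_apply` is a hypothetical name rather than part of torchbearer.

```python
def deep_apply(batch, fn):
    """Apply `fn` to every leaf of a nested list / tuple / dict structure."""
    if isinstance(batch, (list, tuple)):
        # Rebuild the same container type with converted items.
        return type(batch)(deep_apply(item, fn) for item in batch)
    if isinstance(batch, dict):
        return {key: deep_apply(value, fn) for key, value in batch.items()}
    # Leaf value: this is where a tensor would have to() called on it.
    return fn(batch)


batch = {"a": [2.1, 2.9], "b": (5.9, {"c": 1.5})}
converted = deep_apply(batch, int)
print(converted)  # {'a': [2, 2], 'b': (5, {'c': 1})}
```

The container types are preserved, so a tuple batch stays a tuple and a dict batch keeps its keys.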
State
The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced during model fitting. This module defines classes for interacting with state and all of the built in state keys used throughout torchbearer. The state_key() function can be used to create custom state keys for use in callbacks or metrics.
Example:
>>> from torchbearer import state_key
>>> MY_KEY = state_key('my_test_key')
State
- class torchbearer.state.State[source]
State dictionary that behaves like a python dict but accepts StateKeys
- property data
- class torchbearer.state.StateKey(key)[source]
StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.
- Parameters:
key (str) – Base key
Key List
- torchbearer.state.BACKWARD_ARGS = backward_args
The optional arguments which should be passed to the backward call
- torchbearer.state.BATCH = t
The current batch number
- torchbearer.state.CALLBACK_LIST = callback_list
The CallbackList object which is called by the Trial
- torchbearer.state.CRITERION = criterion
The criterion to use when model fitting
- torchbearer.state.DATA = data
The string name of the current data
- torchbearer.state.DATA_TYPE = dtype
The data type of tensors in use by the model, match this to avoid type issues
- torchbearer.state.DEVICE = device
The device currently in use by the Trial and PyTorch model
- torchbearer.state.EPOCH = epoch
The current epoch number
- torchbearer.state.FINAL_PREDICTIONS = final_predictions
The key which maps to the predictions over the dataset when calling predict
- torchbearer.state.GENERATOR = generator
The current data generator (DataLoader)
- torchbearer.state.HISTORY = history
The history list of the Trial instance
- torchbearer.state.INF_TRAIN_LOADING = inf_train_loading
Flag for refreshing of training iterator when finished instead of each epoch
- torchbearer.state.INPUT = x
The current batch of inputs
- torchbearer.state.ITERATOR = iterator
The current iterator
- torchbearer.state.LOADER = loader
The batch loader which handles formatting data from each batch
- torchbearer.state.LOSS = loss
The current value for the loss
- torchbearer.state.MAX_EPOCHS = max_epochs
The total number of epochs to run for
- torchbearer.state.METRICS = metrics
The metric dict from the current batch of data
- torchbearer.state.METRIC_LIST = metric_list
The list of metrics in use by the Trial
- torchbearer.state.MIXUP_LAMBDA = mixup_lambda
The lambda coefficient of the linear combination of inputs
- torchbearer.state.MIXUP_PERMUTATION = mixup_permutation
The permutation of input indices for input mixup
- torchbearer.state.MODEL = model
The PyTorch module / model that will be trained
- torchbearer.state.OPTIMIZER = optimizer
The optimizer to use when model fitting
- torchbearer.state.PREDICTION = y_pred
The current batch of predictions
- torchbearer.state.SAMPLER = sampler
The sampler which loads data from the generator onto the correct device
- torchbearer.state.SELF = self
A self reference to the Trial object for persistence etc.
- torchbearer.state.STEPS = steps
The current number of steps per epoch
- torchbearer.state.STOP_TRAINING = stop_training
A flag that can be set to true to stop the current fit call
- torchbearer.state.TARGET = y_true
The current batch of ground truth data
- torchbearer.state.TEST_DATA = test_data
The flag representing test data
- torchbearer.state.TEST_GENERATOR = test_generator
The test data generator in the Trial object
- torchbearer.state.TEST_STEPS = test_steps
The number of test steps to take
- torchbearer.state.TIMINGS = timings
The timings keys used by the timer callback
- torchbearer.state.TRAIN_DATA = train_data
The flag representing train data
- torchbearer.state.TRAIN_GENERATOR = train_generator
The train data generator in the Trial object
- torchbearer.state.TRAIN_STEPS = train_steps
The number of train steps to take
- torchbearer.state.VALIDATION_DATA = validation_data
The flag representing validation data
- torchbearer.state.VALIDATION_GENERATOR = validation_generator
The validation data generator in the Trial object
- torchbearer.state.VALIDATION_STEPS = validation_steps
The number of validation steps to take
- torchbearer.state.VERSION = torchbearer_version
The torchbearer version
- torchbearer.state.X = x
The current batch of inputs
- torchbearer.state.Y_PRED = y_pred
The current batch of predictions
- torchbearer.state.Y_TRUE = y_true
The current batch of ground truth data
Utilities
- class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)[source]
Generates training and validation split indices for a given dataset length and creates training and validation datasets using these splits
- Parameters:
dataset_len – The length of the dataset to be split into training and validation
split_fraction – The fraction of the whole dataset to be used for validation
shuffle_seed – Optional random seed for the shuffling process
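A seeded index split of this kind can be sketched with the standard library. The function name `split_indices`, the truncating rounding, and taking the validation indices from the front of the shuffled list are all assumptions made for illustration, not the class's documented behaviour.

```python
import random


def split_indices(dataset_len, split_fraction, shuffle_seed=None):
    """Return (train_ids, valid_ids) for a dataset of the given length."""
    ids = list(range(dataset_len))
    # A seeded generator makes the shuffle reproducible across runs.
    random.Random(shuffle_seed).shuffle(ids)
    num_valid = int(dataset_len * split_fraction)
    return ids[num_valid:], ids[:num_valid]


train_ids, valid_ids = split_indices(10, 0.2, shuffle_seed=0)
print(len(train_ids), len(valid_ids))  # 8 2
```

The two index lists are disjoint and together cover the whole dataset, so each can be used to build a subset dataset such as the SubsetDataset described next.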
- class torchbearer.cv_utils.SubsetDataset(dataset, ids)[source]
Dataset that consists of a subset of a previous dataset
- Parameters:
dataset (torch.utils.data.Dataset) – Complete dataset
ids (list) – List of subset IDs
- torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)[source]
Generate validation and training datasets from whole dataset tensors
- Parameters:
x (torch.Tensor) – Data tensor for dataset
y (torch.Tensor) – Label tensor for dataset
validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
validation_split (float) – Fraction of dataset to be used for validation
shuffle (bool) – If True randomize tensor order before splitting else do not randomize
- Returns:
Training and validation datasets
- torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)[source]
Generate training and validation tensors from whole dataset data and label tensors
- Parameters:
x (torch.Tensor) – Data tensor for whole dataset
y (torch.Tensor) – Label tensor for whole dataset
split (float) – Fraction of dataset to be used for validation
shuffle (bool) – If True randomize tensor order before splitting else do not randomize
- Returns:
Training and validation tensors (training data, training labels, validation data, validation labels)
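The relationship between the two functions above (explicit validation data takes the place of a split, otherwise a fraction is held out) can be sketched with plain lists in place of tensors. The names `split_pairs` and `train_valid_sets`, the `seed` argument, and the choice to hold out the leading items are hypothetical.

```python
import random


def split_pairs(x, y, split, shuffle=True, seed=None):
    """Split paired sequences into (x_train, y_train, x_val, y_val)."""
    ids = list(range(len(x)))
    if shuffle:
        random.Random(seed).shuffle(ids)
    num_valid = int(len(ids) * split)
    valid_ids, train_ids = ids[:num_valid], ids[num_valid:]
    return ([x[i] for i in train_ids], [y[i] for i in train_ids],
            [x[i] for i in valid_ids], [y[i] for i in valid_ids])


def train_valid_sets(x, y, validation_data=None, validation_split=0.0, shuffle=True):
    """Prefer explicit validation data; otherwise hold out a fraction of x and y."""
    if validation_data is not None:
        x_val, y_val = validation_data
        return (x, y), (x_val, y_val)
    x_train, y_train, x_val, y_val = split_pairs(x, y, validation_split, shuffle)
    return (x_train, y_train), (x_val, y_val)


x = [10, 11, 12, 13]
y = [0, 1, 0, 1]
explicit = train_valid_sets(x, y, validation_data=([99], [1]))
held_out = train_valid_sets(x, y, validation_split=0.25, shuffle=False)
print(explicit)  # (([10, 11, 12, 13], [0, 1, 0, 1]), ([99], [1]))
print(held_out)  # (([11, 12, 13], [1, 0, 1]), ([10], [0]))
```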
- torchbearer.bases.base_closure(x, model, y_pred, y_true, crit, loss, opt)[source]
Function to create a standard pytorch closure using objects taken from state under the given keys.
- Parameters:
x – State key under which the input data is stored
model – State key under which the pytorch model is stored
y_pred – State key under which the predictions will be stored
y_true – State key under which the targets are stored
crit – State key under which the criterion function is stored (function of state or (y_pred, y_true))
loss – State key under which the loss will be stored
opt – State key under which the optimiser is stored
- Returns:
Standard closure function
- Return type:
function
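The idea of building a closure from state keys can be sketched in plain Python. This is a simplified illustration, not the library's code: it handles only criteria of the form (y_pred, y_true), it assumes the optimiser step itself is taken by the caller, and `make_closure`, `FakeLoss` and `FakeOptimizer` are hypothetical stand-ins for the real PyTorch objects.

```python
class FakeLoss:
    """Stand-in for a loss tensor: records when backward() is called."""
    def __init__(self, value, log):
        self.value, self.log = value, log

    def backward(self):
        self.log.append("backward")


class FakeOptimizer:
    """Stand-in for an optimiser: records when zero_grad() is called."""
    def __init__(self, log):
        self.log = log

    def zero_grad(self):
        self.log.append("zero_grad")


def make_closure(x, model, y_pred, y_true, crit, loss, opt):
    """Build a closure that reads and writes `state` under the given keys."""
    def closure(state):
        state[opt].zero_grad()                                   # clear old gradients
        state[y_pred] = state[model](state[x])                   # forward pass
        state[loss] = state[crit](state[y_pred], state[y_true])  # compute the loss
        state[loss].backward()                                   # back-propagate
    return closure


log = []
state = {
    "x": 3.0,
    "y_true": 5.0,
    "model": lambda x: 2.0 * x,
    "criterion": lambda y_pred, y_true: FakeLoss(abs(y_pred - y_true), log),
    "optimizer": FakeOptimizer(log),
}
closure = make_closure("x", "model", "y_pred", "y_true", "criterion", "loss", "optimizer")
closure(state)
print(state["y_pred"], state["loss"].value, log)  # 6.0 1.0 ['zero_grad', 'backward']
```

Passing keys rather than objects means the same closure can be reused when the values stored in state change between batches.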