torchbearer
Model
class torchbearer.torchbearer.Model(model, optimizer, criterion=None, metrics=[])
Create a torchbearer model, which wraps a base torch model and provides a training environment around it.
Parameters:
- model (torch.nn.Module) – The base PyTorch model
- optimizer (torch.optim.Optimizer) – The optimizer used to update the PyTorch model's weights
- criterion (function or None) – The final loss criterion, which provides a loss value to the optimizer
- metrics (list) – Additional metrics for display and for use within callbacks
cpu()
Moves all model parameters and buffers to the CPU.
Returns: Self (this torchbearer model)
Return type: Model
cuda(device=None)
Moves all model parameters and buffers to the GPU.
Parameters: device (int, optional) – If specified, all parameters will be copied to that device
Returns: Self (this torchbearer model)
Return type: Model
evaluate(x=None, y=None, batch_size=32, verbose=2, steps=None, pass_state=False)
Perform an evaluation loop on the given data and label tensors to compute metrics.
Parameters:
- x (torch.Tensor) – The input data tensor
- y (torch.Tensor) – The target labels for data tensor x
- batch_size (int) – The mini-batch size (number of samples processed in a single forward pass)
- verbose (int) – If 2, use tqdm at the batch level; if 1, use tqdm at the epoch level; otherwise, display no progress
- steps (int) – The number of evaluation mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model's forward method; if False, only the input data is passed
Returns: The dictionary containing final metrics
Return type: dict[str,any]
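The evaluation loop above can be sketched in plain Python, with lists standing in for tensors. This is a minimal sketch under stated assumptions, not torchbearer's implementation: `evaluate_sketch`, `mse`, and the `"loss"` key are hypothetical, the model and criterion are ordinary callables, and `steps` is assumed to cap the number of mini-batches evaluated.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence


def mse(preds: List[float], targets: Sequence[float]) -> float:
    """Mean squared error over one mini-batch (hypothetical criterion)."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def evaluate_sketch(
    model: Callable[[float], float],
    criterion: Callable[[List[float], Sequence[float]], float],
    x: Sequence[float],
    y: Sequence[float],
    batch_size: int = 32,
    steps: Optional[int] = None,
) -> Dict[str, float]:
    """Hypothetical evaluation loop: forward passes and metric accumulation, no weight updates."""
    # Mini-batches needed to cover the data once (the last batch may be partial).
    n_batches = math.ceil(len(x) / batch_size)
    # Assumption: `steps` caps the number of mini-batches that are evaluated.
    if steps is not None:
        n_batches = min(n_batches, steps)

    total_loss, seen = 0.0, 0
    for b in range(n_batches):
        start, stop = b * batch_size, (b + 1) * batch_size
        xb, yb = x[start:stop], y[start:stop]
        preds = [model(v) for v in xb]  # forward pass only
        # Weight each batch loss by its size so a short final batch is not over-counted.
        total_loss += criterion(preds, yb) * len(xb)
        seen += len(xb)
    if seen == 0:
        raise ValueError("no samples were evaluated")
    return {"loss": total_loss / seen}
```

For example, `evaluate_sketch(lambda v: 2 * v, mse, [1, 2, 3, 4, 5], [2, 4, 6, 8, 11], batch_size=2)` runs three mini-batches (the last holding one sample) and returns `{'loss': 0.2}`; passing `steps=2` stops after the first two batches.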
evaluate_generator(generator, verbose=2, steps=None, pass_state=False)
Perform an evaluation loop on the given data generator to compute metrics.
Parameters:
- generator (DataLoader) – The evaluation data generator (usually a PyTorch DataLoader)
- verbose (int) – If 2, use tqdm at the batch level; if 1, use tqdm at the epoch level; otherwise, display no progress
- steps (int) – The number of evaluation mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model's forward method; if False, only the input data is passed
Returns: The dictionary containing final metrics
Return type: dict[str,any]
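The effect of `pass_state` on the forward call can be illustrated with a small plain-Python sketch. `StatefulModel`, `forward_sketch`, and the `"x"` and `"scale"` state keys are hypothetical stand-ins, and the sketch assumes the state dictionary is handed over as a `state` keyword argument, which the parameter description does not specify.

```python
from typing import Any, Dict, List, Optional


class StatefulModel:
    """Hypothetical model whose forward method can optionally read the state dict."""

    def forward(self, x: List[float], state: Optional[Dict[str, Any]] = None) -> List[float]:
        # When a state dict is supplied, read a scale factor from it; otherwise use 1.0.
        scale = state.get("scale", 1.0) if state is not None else 1.0
        return [scale * v for v in x]


def forward_sketch(model: StatefulModel, state: Dict[str, Any], pass_state: bool) -> List[float]:
    """Call the model the way the pass_state flag describes."""
    if pass_state:
        # pass_state=True: the whole state dictionary travels with the input.
        return model.forward(state["x"], state=state)
    # pass_state=False: only the input data is passed.
    return model.forward(state["x"])
```

With `state = {"x": [1.0, 2.0], "scale": 3.0}`, `forward_sketch(StatefulModel(), state, True)` returns `[3.0, 6.0]`, while `pass_state=False` returns `[1.0, 2.0]` because the model never sees the scale factor.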
fit(x, y, batch_size=None, epochs=1, verbose=2, callbacks=[], validation_split=None, validation_data=None, shuffle=True, initial_epoch=0, steps_per_epoch=None, validation_steps=None, workers=1, pass_state=False)
Fit the model to the given data and label tensors.
Parameters:
- x (torch.Tensor) – The input data tensor
- y (torch.Tensor) – The target labels for data tensor x
- batch_size (int) – The mini-batch size (number of samples processed for a single weight update)
- epochs (int) – The number of training epochs to run (in one epoch, each sample in the dataset is seen once)
- verbose (int) – If 2, use tqdm at the batch level; if 1, use tqdm at the epoch level; otherwise, display no training progress
- callbacks (list) – The list of torchbearer callbacks to be called during training and validation
- validation_split (float) – Fraction of the training dataset to set aside for validation
- validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data and label tensors (x_val, y_val), used instead of splitting x and y
- shuffle (bool) – If True, mini-batches of training/validation data are drawn in random order; if False, samples are taken in the order defined by the dataset
- initial_epoch (int) – The index of the first epoch; useful for resuming training after a number of epochs have already been run
- steps_per_epoch (int) – The number of training mini-batches to run per epoch
- validation_steps (int) – The number of validation mini-batches to run per epoch
- workers (int) – The number of CPU workers devoted to loading and aggregating batches
- pass_state (bool) – If True, the state dictionary is passed to the torch model's forward method; if False, only the input data is passed
Returns: The final state context dictionary
Return type: dict[str,any]
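How `epochs`, `initial_epoch`, `shuffle`, `batch_size`, and `steps_per_epoch` could interact is sketched below by listing the sample-index mini-batches each epoch would visit. This is an illustration under assumptions, not torchbearer's code: `batch_schedule` and its `seed` argument are hypothetical, epochs are assumed to run from `initial_epoch` up to (but not including) `epochs` as in Keras, and `steps_per_epoch` is assumed to cap the batches per epoch.

```python
import math
import random
from typing import Dict, List, Optional


def batch_schedule(
    n_samples: int,
    batch_size: int,
    epochs: int = 1,
    initial_epoch: int = 0,
    shuffle: bool = True,
    steps_per_epoch: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dict[int, List[List[int]]]:
    """Map each epoch index to the mini-batches of sample indices it would visit."""
    rng = random.Random(seed)
    n_batches = math.ceil(n_samples / batch_size)
    # Assumption: steps_per_epoch caps the number of mini-batches per epoch.
    if steps_per_epoch is not None:
        n_batches = min(n_batches, steps_per_epoch)

    schedule: Dict[int, List[List[int]]] = {}
    # Assumption (Keras-style): epochs before initial_epoch are treated as already run.
    for epoch in range(initial_epoch, epochs):
        order = list(range(n_samples))
        if shuffle:
            rng.shuffle(order)  # a fresh random order every epoch
        schedule[epoch] = [
            order[b * batch_size:(b + 1) * batch_size] for b in range(n_batches)
        ]
    return schedule
```

`batch_schedule(5, 2, epochs=3, initial_epoch=1, shuffle=False)` returns `{1: [[0, 1], [2, 3], [4]], 2: [[0, 1], [2, 3], [4]]}`: epoch 0 is skipped, and each remaining epoch covers all five samples in three mini-batches.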
fit_generator(generator, train_steps=None, epochs=1, verbose=2, callbacks=[], validation_generator=None, validation_steps=None, initial_epoch=0, pass_state=False)
Fit the model to the given data generator.
Parameters:
- generator (DataLoader) – The training data generator (usually a PyTorch DataLoader)
- train_steps (int) – The number of training mini-batches to run per epoch
- epochs (int) – The number of training epochs to run (in one epoch, each sample in the dataset is seen once)
- verbose (int) – If 2, use tqdm at the batch level; if 1, use tqdm at the epoch level; otherwise, display no training progress
- callbacks (list) – The list of torchbearer callbacks to be called during training and validation
- validation_generator (DataLoader) – The validation data generator (usually a PyTorch DataLoader)
- validation_steps (int) – The number of validation mini-batches to run per epoch
- initial_epoch (int) – The index of the first epoch; useful for resuming training after a number of epochs have already been run
- pass_state (bool) – If True, the state dictionary is passed to the torch model's forward method; if False, only the input data is passed
Returns: The final state context dictionary
Return type: dict[str,any]
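A training loop over a generator that threads a state dictionary through callbacks, and returns it at the end, might look like the following. The hook names (`on_start_epoch`, `on_step`, `on_end_epoch`), the state keys, `LossHistory`, and `fit_generator_sketch` are all hypothetical and chosen for illustration; torchbearer's real callback API and state keys may differ. The forward pass, loss, backward pass, and optimizer step are collapsed into a single placeholder `step` function, and the generator is assumed to be re-iterable, like a DataLoader.

```python
import itertools
from typing import Any, Callable, Dict, Iterable, List, Optional


class Callback:
    """Hypothetical callback base class; every hook receives the shared state dict."""

    def on_start_epoch(self, state: Dict[str, Any]) -> None: ...
    def on_step(self, state: Dict[str, Any]) -> None: ...
    def on_end_epoch(self, state: Dict[str, Any]) -> None: ...


class LossHistory(Callback):
    """Example callback: records the last loss seen in each epoch."""

    def __init__(self) -> None:
        self.history: List[float] = []

    def on_end_epoch(self, state: Dict[str, Any]) -> None:
        self.history.append(state["loss"])


def fit_generator_sketch(
    generator: Iterable[Any],
    step: Callable[[Any], float],
    train_steps: Optional[int] = None,
    epochs: int = 1,
    callbacks: Optional[List[Callback]] = None,
    initial_epoch: int = 0,
) -> Dict[str, Any]:
    callbacks = callbacks or []
    state: Dict[str, Any] = {"epoch": initial_epoch, "loss": None}
    for epoch in range(initial_epoch, epochs):
        state["epoch"] = epoch
        for cb in callbacks:
            cb.on_start_epoch(state)
        # Iterate the generator afresh each epoch; train_steps (if given) caps the batches.
        for batch in itertools.islice(generator, train_steps):
            state["loss"] = step(batch)  # placeholder for forward, loss, backward, optimizer step
            for cb in callbacks:
                cb.on_step(state)
        for cb in callbacks:
            cb.on_end_epoch(state)
    return state  # the final state dictionary
```

Calling `fit_generator_sketch([1.0, 2.0, 3.0], lambda b: b * 10, epochs=2, callbacks=[LossHistory()])` returns `{'epoch': 1, 'loss': 30.0}`. Passing the same dictionary to every hook lets callbacks read and modify training progress without extra arguments.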
load_state_dict(state_dict, **kwargs)
Copies parameters and buffers from state_dict() into this module and its descendants.
Parameters:
- state_dict (dict) – A dict containing parameters and persistent buffers
- kwargs – See: torch.nn.Module.load_state_dict
predict(x=None, batch_size=32, verbose=2, steps=None, pass_state=False)
Perform a prediction loop on the given data tensor to predict labels.
Parameters:
- x (torch.Tensor) – The input data tensor
- batch_size (int) – The mini-batch size (number of samples processed in a single forward pass)
- verbose (int) – If 2, use tqdm at the batch level; if 1, use tqdm at the epoch level; otherwise, display no progress
- steps (int) – The number of prediction mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model's forward method; if False, only the input data is passed
Returns: Tensor of final predicted labels
Return type: torch.Tensor
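The prediction loop above runs the model over successive mini-batches and joins the per-batch outputs into one result. A plain-Python sketch, with lists standing in for tensors; `predict_sketch` is a hypothetical name, `steps` is assumed to cap the number of batches, and with real tensors one natural choice for the final join is `torch.cat`.

```python
from typing import Callable, List, Optional, Sequence


def predict_sketch(
    model: Callable[[Sequence[float]], List[float]],
    x: Sequence[float],
    batch_size: int = 32,
    steps: Optional[int] = None,
) -> List[float]:
    """Run `model` on successive mini-batches of `x` and concatenate the outputs."""
    outputs: List[float] = []
    for b, start in enumerate(range(0, len(x), batch_size)):
        if steps is not None and b >= steps:
            break  # assumption: `steps` caps the number of prediction mini-batches
        outputs.extend(model(x[start:start + batch_size]))  # one forward pass per batch
    return outputs
```

`predict_sketch(lambda xb: [v * v for v in xb], [1, 2, 3, 4, 5], batch_size=2)` returns `[1, 4, 9, 16, 25]`; with `steps=2` it returns `[1, 4, 9, 16]`.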
predict_generator(generator, verbose=2, steps=None, pass_state=False)
Perform a prediction loop on the given data generator to predict labels.
Parameters:
- generator (DataLoader) – The prediction data generator (usually a PyTorch DataLoader)
- verbose (int) – If 2, use tqdm at the batch level; if 1, use tqdm at the epoch level; otherwise, display no progress
- steps (int) – The number of prediction mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model's forward method; if False, only the input data is passed
Returns: Tensor of final predicted labels
Return type: torch.Tensor
state_dict(**kwargs)
Parameters: kwargs – See: torch.nn.Module.state_dict
Returns: A dict containing parameters and persistent buffers
Return type: dict
to(*args, **kwargs)
Moves and/or casts the parameters and buffers.
Parameters:
- args – See: torch.nn.Module.to
- kwargs – See: torch.nn.Module.to
Returns: Self (this torchbearer model)
Return type: Model
Utilities
torchbearer.state.state_key(key)
Computes and returns a non-conflicting key for the state dictionary, given a seed key.
Parameters: key (str) – The seed key on which the new state key is based
Returns: New state key
Return type: str
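One way to produce non-conflicting keys from a seed key is to keep a registry of keys already issued and append a counter on collision. `state_key_sketch` below is a hypothetical illustration of that idea; torchbearer's actual suffix format and bookkeeping may differ.

```python
from typing import Set

_issued_keys: Set[str] = set()  # every key handed out so far (hypothetical registry)


def state_key_sketch(key: str) -> str:
    """Return `key` if unused, else `key` with the smallest free numeric suffix."""
    candidate, n = key, 1
    while candidate in _issued_keys:
        candidate = f"{key}_{n}"
        n += 1
    _issued_keys.add(candidate)
    return candidate
```

Asking for `"acc"` three times in a fresh process yields `"acc"`, `"acc_1"`, and `"acc_2"`, so two metrics that share a name still write to separate state entries.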
class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)
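No description is given for this class. Judging only from its signature, it presumably splits the sample indices 0 to dataset_len - 1 into training and validation sets, with `split_fraction` of them going to validation and `shuffle_seed` controlling the shuffle. The class below is a hypothetical sketch of that reading, not the library's implementation; the name `IndexSplitterSketch`, the attributes `train_ids` and `valid_ids`, and the choice to always shuffle (seeded when a seed is given) are assumptions.

```python
import random
from typing import List, Optional


class IndexSplitterSketch:
    """Hypothetical index-level train/validation splitter."""

    def __init__(self, dataset_len: int, split_fraction: float, shuffle_seed: Optional[int] = None) -> None:
        ids = list(range(dataset_len))
        # Assumption: always shuffle; a fixed seed makes the split reproducible.
        random.Random(shuffle_seed).shuffle(ids)
        n_valid = int(dataset_len * split_fraction)  # validation size, rounded down
        self.valid_ids: List[int] = ids[:n_valid]
        self.train_ids: List[int] = ids[n_valid:]
```

Two splitters built with the same `shuffle_seed` produce identical splits, which is what makes a seeded split reproducible across runs.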
torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)
Generate training and validation datasets from whole-dataset tensors.
Parameters:
- x (torch.Tensor) – Data tensor for the dataset
- y (torch.Tensor) – Label tensor for the dataset
- validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting the x and y tensors
- validation_split (float) – Fraction of the dataset to be used for validation
- shuffle (bool) – If True, shuffle the tensors before splitting; otherwise, split them in their existing order
Returns: Training and validation datasets
Return type: tuple
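The parameter descriptions above imply a precedence: explicit `validation_data` is used as-is, and otherwise `validation_split` carves a validation set out of `x` and `y`. The sketch below encodes that reading with lists of `(sample, label)` pairs in place of datasets. `train_valid_sets_sketch` and its `seed` argument are hypothetical, and two further details are assumptions: the validation portion is taken from the front of the (possibly shuffled) data, and no validation set is returned when neither option is given.

```python
import random
from typing import List, Optional, Sequence, Tuple

Pair = Tuple[float, float]


def train_valid_sets_sketch(
    x: Sequence[float],
    y: Sequence[float],
    validation_data: Optional[Tuple[Sequence[float], Sequence[float]]] = None,
    validation_split: Optional[float] = None,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Tuple[List[Pair], Optional[List[Pair]]]:
    """Return (training pairs, validation pairs or None)."""
    pairs = list(zip(x, y))  # keep each sample attached to its label
    if validation_data is not None:
        # Explicit validation data takes precedence; all of x, y is used for training.
        x_val, y_val = validation_data
        return pairs, list(zip(x_val, y_val))
    if validation_split is not None:
        if shuffle:
            random.Random(seed).shuffle(pairs)  # shuffling pairs keeps labels aligned
        n_valid = int(len(pairs) * validation_split)
        # Assumption: the validation portion is taken from the front.
        return pairs[n_valid:], pairs[:n_valid]
    return pairs, None  # assumption: no validation set was requested
```

For instance, with `x = [1, 2, 3, 4]`, `y = [0, 1, 0, 1]`, `validation_split=0.25`, and `shuffle=False`, the sketch returns `[(2, 1), (3, 0), (4, 1)]` for training and `[(1, 0)]` for validation.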
torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)
Generate training and validation tensors from whole-dataset data and label tensors.
Parameters:
- x (torch.Tensor) – Data tensor for the whole dataset
- y (torch.Tensor) – Label tensor for the whole dataset
- split (float) – Fraction of the dataset to be used for validation
- shuffle (bool) – If True, shuffle the tensors before splitting; otherwise, split them in their existing order
Returns: Training and validation tensors (training data, training labels, validation data, validation labels)
Return type: tuple
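The key detail when shuffling before a split is that the data and the labels must be reordered by the same permutation. A plain-Python sketch that returns the four sequences in the documented order; `train_valid_splitter_sketch` and its `seed` argument are hypothetical, and taking the validation portion from the front of the shuffled order is an assumption.

```python
import random
from typing import List, Optional, Sequence, Tuple


def train_valid_splitter_sketch(
    x: Sequence[float],
    y: Sequence[float],
    split: float,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Return (training data, training labels, validation data, validation labels)."""
    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of samples")
    order = list(range(len(x)))
    if shuffle:
        # One permutation for both inputs keeps every sample aligned with its label.
        random.Random(seed).shuffle(order)
    n_valid = int(len(order) * split)
    # Assumption: the validation portion is taken from the front of the order.
    valid_idx, train_idx = order[:n_valid], order[n_valid:]
    return (
        [x[i] for i in train_idx],
        [y[i] for i in train_idx],
        [x[i] for i in valid_idx],
        [y[i] for i in valid_idx],
    )
```

With `shuffle=False` and `split=0.4`, the inputs `[10, 20, 30, 40, 50]` and `[1, 2, 3, 4, 5]` give `([30, 40, 50], [3, 4, 5], [10, 20], [1, 2])`. With real tensors, the same idea is commonly expressed by drawing one index permutation with `torch.randperm` and using it to index both `x` and `y`.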