# Source code for torchbearer.torchbearer

import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset

import torchbearer
from torchbearer import metrics as torchbearer_metrics
from torchbearer.callbacks.aggregate_predictions import AggregatePredictions
from torchbearer.callbacks.callbacks import CallbackList
from torchbearer.callbacks.printer import Tqdm


[docs]class Model: """ .. deprecated:: 0.2.0 Use :class:`.Trial` instead. Create torchbearermodel which wraps a base torchmodel and provides a training environment surrounding it :param model: The base pytorch model :type model: torch.nn.Module :param optimizer: The optimizer used for pytorch model weight updates :type optimizer: torch.optim.Optimizer :param criterion: The final loss criterion that provides a loss value to the optimizer :type criterion: function or None :param metrics: Additional metrics for display and use within callbacks :type metrics: list """ def __init__(self, model, optimizer, criterion=None, metrics=[]): super().__init__() warnings.warn( 'torchbearer.Model and all of its attributes are deprecated as of version 0.2.0. Use torchbearer.Trial instead', DeprecationWarning) warnings.warn( 'torchbearer.Model and all of its attributes are deprecated as of version 0.2.0. Use torchbearer.Trial instead', UserWarning) if criterion is None: criterion = lambda y_pred, y_true: torch.zeros(1, device=y_true.device) self.main_state = { torchbearer.MODEL: model, torchbearer.CRITERION: criterion, torchbearer.OPTIMIZER: optimizer, torchbearer.DEVICE: 'cpu', torchbearer.HISTORY: [], # To retain some compatability with new callbacks torchbearer.DATA_TYPE: torch.float32, torchbearer.METRIC_LIST: torchbearer_metrics.MetricList(metrics), torchbearer.SELF: self, torchbearer.CALLBACK_LIST: torchbearer.callbacks.CallbackList([]) }
[docs] def fit(self, x, y, batch_size=None, epochs=1, verbose=2, callbacks=[], validation_split=None, validation_data=None, shuffle=True, initial_epoch=0, steps_per_epoch=None, validation_steps=None, workers=1, pass_state=False): """ Perform fitting of a model to given data and label tensors :param x: The input data tensor :type x: torch.Tensor :param y: The target labels for data tensor x :type y: torch.Tensor :param batch_size: The mini-batch size (number of samples processed for a single weight update) :type batch_size: int :param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once) :type epochs: int :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress :type verbose: int :param callbacks: The list of torchbearer callbacks to be called during training and validation :type callbacks: list :param validation_split: Fraction of the training dataset to be set aside for validation testing :type validation_split: float :param validation_data: Optional validation data tensor :type validation_data: (torch.Tensor, torch.Tensor) :param shuffle: If True mini-batches of training/validation data are randomly selected, if False mini-batches samples are selected in order defined by dataset :type shuffle: bool :param initial_epoch: The integer value representing the first epoch - useful for continuing training after a number of epochs :type initial_epoch: int :param steps_per_epoch: The number of training mini-batches to run per epoch :type steps_per_epoch: int :param validation_steps: The number of validation mini-batches to run per epoch :type validation_steps: int :param workers: The number of cpu workers devoted to batch loading and aggregating :type workers: int :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: The final state context dictionary :rtype: dict[str,any] 
""" trainset, valset = torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=shuffle) trainloader = DataLoader(trainset, batch_size, shuffle=shuffle, num_workers=workers) if valset is not None: valloader = DataLoader(valset, batch_size, shuffle=shuffle, num_workers=workers) else: valloader = None return self.fit_generator(trainloader, train_steps=steps_per_epoch, epochs=epochs, verbose=verbose, callbacks=callbacks, validation_generator=valloader, validation_steps=validation_steps, initial_epoch=initial_epoch, pass_state=pass_state)
[docs] def fit_generator(self, generator, train_steps=None, epochs=1, verbose=2, callbacks=[], validation_generator=None, validation_steps=None, initial_epoch=0, pass_state=False): """ Perform fitting of a model to given data generator :param generator: The training data generator (usually a pytorch DataLoader) :type generator: DataLoader :param train_steps: The number of training mini-batches to run per epoch :type train_steps: int :param epochs: The number of training epochs to be run (each sample from the dataset is viewed exactly once) :type epochs: int :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress :type verbose: int :param callbacks: The list of torchbearer callbacks to be called during training and validation :type callbacks: list :param validation_generator: The validation data generator (usually a pytorch DataLoader) :type validation_generator: DataLoader :param validation_steps: The number of validation mini-batches to run per epoch :type validation_steps: int :param initial_epoch: The integer value representing the first epoch - useful for continuing training after a number of epochs :type initial_epoch: int :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: The final state context dictionary :rtype: dict[str,any] """ callbacks = Model._add_printer(callbacks, verbose) _callbacks = CallbackList(callbacks) # Get train and validation steps if validation_steps is None and validation_generator is not None: validation_steps = len(validation_generator) if train_steps is None: train_steps = len(generator) if generator is not None and train_steps > len(generator): train_steps = len(generator) if not isinstance(train_steps, int): train_steps = int(train_steps) warnings.warn("Number of training steps is not an int, converting to int") if not isinstance(epochs, int): if isinstance(epochs, float): 
epochs = int(epochs) warnings.warn("Number of epochs is a float, converting to int") else: warnings.warn("Number of epochs is neither float nor int, setting to 0") epochs = 0 # Init state state = { torchbearer.MAX_EPOCHS: epochs, torchbearer.TRAIN_STEPS: train_steps, torchbearer.STEPS: train_steps, torchbearer.BATCH: 0, torchbearer.TRAIN_GENERATOR: generator, torchbearer.STOP_TRAINING: False } state.update(self.main_state) state[torchbearer.CALLBACK_LIST] = state[torchbearer.CALLBACK_LIST].copy() state[torchbearer.CALLBACK_LIST].append(_callbacks) state[torchbearer.CALLBACK_LIST].on_start(state) for state[torchbearer.EPOCH] in range(initial_epoch, epochs): state[torchbearer.CALLBACK_LIST].on_start_epoch(state) if state[torchbearer.TRAIN_GENERATOR] is not None: state[torchbearer.GENERATOR] = state[torchbearer.TRAIN_GENERATOR] state[torchbearer.TRAIN_ITERATOR] = iter(state[torchbearer.TRAIN_GENERATOR]) state[torchbearer.ITERATOR] = state[torchbearer.TRAIN_ITERATOR] self.train() state[torchbearer.CALLBACK_LIST].on_start_training(state) state[torchbearer.METRIC_LIST].reset(state) state[torchbearer.METRICS] = {} for state[torchbearer.BATCH] in range(0, state[torchbearer.TRAIN_STEPS]): # Extract batch if state[torchbearer.TRAIN_GENERATOR] is None: # TODO: Replace with flag check self._load_batch_none(torchbearer.TRAIN_ITERATOR, state) else: self._load_batch_standard(torchbearer.TRAIN_ITERATOR, state) state[torchbearer.CALLBACK_LIST].on_sample(state) # Zero grads state[torchbearer.OPTIMIZER].zero_grad() # Forward pass if pass_state: state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X], state=state) else: state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X]) state[torchbearer.CALLBACK_LIST].on_forward(state) # Loss Calculation state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]) state[torchbearer.CALLBACK_LIST].on_criterion(state) state[torchbearer.METRICS] = 
state[torchbearer.METRIC_LIST].process(state) # Backwards pass state[torchbearer.LOSS].backward() state[torchbearer.CALLBACK_LIST].on_backward(state) # Update parameters state[torchbearer.OPTIMIZER].step() state[torchbearer.CALLBACK_LIST].on_step_training(state) if state[torchbearer.STOP_TRAINING]: break state[torchbearer.METRICS].update(state[torchbearer.METRIC_LIST].process_final(state)) final_metrics = state[torchbearer.METRICS] state[torchbearer.CALLBACK_LIST].on_end_training(state) # Validate if validation_generator is not None or validation_steps is not None: state[torchbearer.VALIDATION_GENERATOR] = validation_generator state[torchbearer.GENERATOR] = validation_generator state[torchbearer.VALIDATION_STEPS] = validation_steps state[torchbearer.STEPS] = validation_steps self.eval() self._validate(state, state[torchbearer.CALLBACK_LIST], pass_state) final_metrics.update(state[torchbearer.METRICS]) state[torchbearer.METRICS] = final_metrics state[torchbearer.CALLBACK_LIST].on_end_epoch(state) if state[torchbearer.STOP_TRAINING]: break state[torchbearer.CALLBACK_LIST].on_end(state) return state
def _test_loop(self, state, callbacks, pass_state, batch_loader, num_steps=None): """ The generic testing loop used for validation, evaluation and prediction :param state: The current state context dictionary :type state: dict[str,any] :param callbacks: The list of torchbearer callbacks to be called during training and validation :type callbacks: CallbackList :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :param batch_loader: The batch loader to use :type batch_loader: function :param num_steps: The number of testing mini-batches to run :return: The state context dictionary :rtype: dict[str,any] """ with torch.no_grad(): state[torchbearer.CALLBACK_LIST] = callbacks state[torchbearer.METRIC_LIST].reset(state) state[torchbearer.METRICS] = {} if num_steps is None: num_steps = len(state[torchbearer.VALIDATION_GENERATOR]) if state[torchbearer.VALIDATION_GENERATOR] is not None and num_steps > len(state[torchbearer.VALIDATION_GENERATOR]): num_steps = len(state[torchbearer.VALIDATION_GENERATOR]) if not isinstance(num_steps, int): num_steps = int(num_steps) warnings.warn('Num test steps is not an int, converting to int.', Warning) state[torchbearer.VALIDATION_STEPS] = num_steps state[torchbearer.STEPS] = num_steps if state[torchbearer.VALIDATION_GENERATOR] is not None: state[torchbearer.VALIDATION_ITERATOR] = iter(state[torchbearer.VALIDATION_GENERATOR]) state[torchbearer.ITERATOR] = state[torchbearer.VALIDATION_ITERATOR] state[torchbearer.CALLBACK_LIST].on_start_validation(state) for state[torchbearer.BATCH] in range(state[torchbearer.VALIDATION_STEPS]): # Load batch batch_loader(torchbearer.VALIDATION_ITERATOR, state) state[torchbearer.CALLBACK_LIST].on_sample_validation(state) # Forward pass if pass_state: state[torchbearer.Y_PRED] = state[torchbearer.MODEL](state[torchbearer.X], state=state) else: state[torchbearer.Y_PRED] = 
state[torchbearer.MODEL](state[torchbearer.X]) state[torchbearer.CALLBACK_LIST].on_forward_validation(state) # Loss and metrics if torchbearer.Y_TRUE in state: state[torchbearer.LOSS] = state[torchbearer.CRITERION](state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]) state[torchbearer.CALLBACK_LIST].on_criterion_validation(state) state[torchbearer.METRICS] = state[torchbearer.METRIC_LIST].process(state) state[torchbearer.CALLBACK_LIST].on_step_validation(state) if state[torchbearer.STOP_TRAINING]: break if torchbearer.Y_TRUE in state: state[torchbearer.METRICS].update(state[torchbearer.METRIC_LIST].process_final(state)) state[torchbearer.CALLBACK_LIST].on_end_validation(state) return state def _validate(self, state, _callbacks, pass_state): """ Perform a validation loop :param state: The current context state dictionary :param _callbacks: The list of torchbearer callbacks to be called during validation loop :type callbacks: CallbackList :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: None :rtype: None """ self._test_loop(state, _callbacks, pass_state, self._load_batch_standard, state[torchbearer.VALIDATION_STEPS])
[docs] def evaluate(self, x=None, y=None, batch_size=32, verbose=2, steps=None, pass_state=False): """ Perform an evaluation loop on given data and label tensors to evaluate metrics :param x: The input data tensor :type x: torch.Tensor :param y: The target labels for data tensor x :type y: torch.Tensor :param batch_size: The mini-batch size (number of samples processed for a single weight update) :type batch_size: int :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress :type verbose: int :param steps: The number of evaluation mini-batches to run :type steps: int :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: The dictionary containing final metrics :rtype: dict[str,any] """ trainset = DataLoader(TensorDataset(x, y), batch_size, steps) return self.evaluate_generator(trainset, verbose, pass_state=pass_state)
[docs] def evaluate_generator(self, generator, verbose=2, steps=None, pass_state=False): """ Perform an evaluation loop on given data generator to evaluate metrics :param generator: The evaluation data generator (usually a pytorch DataLoader) :type generator: DataLoader :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress :type verbose: int :param steps: The number of evaluation mini-batches to run :type steps: int :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: The dictionary containing final metrics :rtype: dict[str,any] """ state = {torchbearer.EPOCH: 0, torchbearer.MAX_EPOCHS: 1, torchbearer.STOP_TRAINING: False, torchbearer.VALIDATION_GENERATOR: generator} state.update(self.main_state) _callbacks = Model._add_printer([], verbose, validation_label_letter='e') if state[torchbearer.VALIDATION_GENERATOR] is None: batch_loader = self._load_batch_none else: batch_loader = self._load_batch_standard self._test_loop(state, CallbackList(_callbacks), pass_state, batch_loader, steps) return state[torchbearer.METRICS]
[docs] def predict(self, x=None, batch_size=32, verbose=2, steps=None, pass_state=False): """ Perform a prediction loop on given data tensor to predict labels :param x: The input data tensor :type x: torch.Tensor :param batch_size: The mini-batch size (number of samples processed for a single weight update) :type batch_size: int :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress :type verbose: int :param steps: The number of evaluation mini-batches to run :type steps: int :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: Tensor of final predicted labels :rtype: torch.Tensor """ pred_set = DataLoader(TensorDataset(x), batch_size, steps) return self.predict_generator(pred_set, verbose, pass_state=pass_state)
[docs] def predict_generator(self, generator, verbose=2, steps=None, pass_state=False): """Perform a prediction loop on given data generator to predict labels :param generator: The prediction data generator (usually a pytorch DataLoader) :type generator: DataLoader :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress :type verbose: int :param steps: The number of evaluation mini-batches to run :type steps: int :param pass_state: If True the state dictionary is passed to the torch model forward method, if False only the input data is passed :type pass_state: bool :return: Tensor of final predicted labels :rtype: torch.Tensor """ state = {torchbearer.EPOCH: 0, torchbearer.MAX_EPOCHS: 1, torchbearer.STOP_TRAINING: False, torchbearer.VALIDATION_GENERATOR: generator} state.update(self.main_state) _callbacks = Model._add_printer([AggregatePredictions()], verbose, validation_label_letter='p') self._test_loop(state, CallbackList(_callbacks), pass_state, self._load_batch_predict, steps) return state[torchbearer.FINAL_PREDICTIONS]
[docs] def train(self): """ Set model and metrics to training mode """ self.main_state[torchbearer.MODEL].train() self.main_state[torchbearer.METRIC_LIST].train()
[docs] def eval(self): """ Set model and metrics to evaluation mode """ self.main_state[torchbearer.MODEL].eval() self.main_state[torchbearer.METRIC_LIST].eval()
[docs] def to(self, *args, **kwargs): """ Moves and/or casts the parameters and buffers. :param args: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_ :param kwargs: See: `torch.nn.Module.to <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.to>`_ :return: Self torchbearermodel :rtype: Model """ self.main_state[torchbearer.MODEL].to(*args, **kwargs) for state in self.main_state[torchbearer.OPTIMIZER].state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(*args, **kwargs) self.main_state = Model._update_device_and_dtype_from_args(self.main_state, *args, **kwargs) return self
[docs] def cuda(self, device=None): """ Moves all model parameters and buffers to the GPU. :param device: if specified, all parameters will be copied to that device :type device: int, optional :return: Self torchbearermodel :rtype: Model """ if device is None: device = torch.cuda.current_device() return self.to('cuda:' + str(device))
[docs] def cpu(self): """ Moves all model parameters and buffers to the CPU. :return: Self torchbearermodel :rtype: Model """ return self.to('cpu')
[docs] def load_state_dict(self, state_dict, **kwargs): """ Copies parameters and buffers from :func:`state_dict` into this module and its descendants. :param state_dict: A dict containing parameters and persistent buffers. :type state_dict: dict :param kwargs: See: `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.load_state_dict>`_ """ self.main_state[torchbearer.MODEL].load_state_dict(state_dict[torchbearer.MODEL], **kwargs) self.main_state[torchbearer.OPTIMIZER].load_state_dict(state_dict[torchbearer.OPTIMIZER])
[docs] def state_dict(self, **kwargs): """ :param kwargs: See: `torch.nn.Module.state_dict <https://pytorch.org/docs/stable/nn.html?highlight=#torch.nn.Module.state_dict>`_ :return: A dict containing parameters and persistent buffers. :rtype: dict """ state_dict = { torchbearer.MODEL: self.main_state[torchbearer.MODEL].state_dict(**kwargs), torchbearer.OPTIMIZER: self.main_state[torchbearer.OPTIMIZER].state_dict() } return state_dict
@staticmethod def _add_printer(callbacks, verbose, validation_label_letter='v'): """Static method used to add the printer callback to the given list for the given verbose level :param callbacks: The list to add to :type callbacks: list :param verbose: 2, 1 or 0, Most -> Least verbose :type verbose: int :param validation_label_letter: Pass to Tqdm :type validation_label_letter: str :return: The updated list :rtype: list """ if verbose >= 2: return [Tqdm(validation_label_letter=validation_label_letter)] + callbacks elif verbose >= 1: return [Tqdm(validation_label_letter=validation_label_letter, on_epoch=True)] + callbacks else: return callbacks @staticmethod def _deep_to(batch, device, dtype): """ Static method to call :func:`to` on tensors or tuples. All items in tuple will have :func:_deep_to called :param batch: The mini-batch which requires a :func:`to` call :type batch: tuple, list, torch.Tensor :param device: The desired device of the batch :type device: torch.device :param dtype: The desired datatype of the batch :type dtype: torch.dtype :return: The moved or casted batch :rtype: tuple, list, torch.Tensor """ is_tuple = isinstance(batch, tuple) if isinstance(batch, list) or isinstance(batch, tuple): batch = list(batch) for i in range(len(batch)): batch[i] = Model._deep_to(batch[i], device, dtype) batch = tuple(batch) if is_tuple else batch elif isinstance(batch, dict): for key in batch: batch[key] = Model._deep_to(batch[key], device, dtype) else: if batch.dtype.is_floating_point: batch = batch.to(device, dtype) else: batch = batch.to(device) return batch @staticmethod def _load_batch_standard(iterator, state): """ Static method to load a standard (input data, target) tuple mini-batch from iterator into state :param iterator: Training or validation data iterator :type iterator: iterable :param state: The current state dict of the :class:`Model`. 
:type state: dict[str,any] """ state[torchbearer.X], state[torchbearer.Y_TRUE] = Model._deep_to(next(state[iterator]), state[torchbearer.DEVICE], state[torchbearer.DATA_TYPE]) @staticmethod def _load_batch_none(_, state): """Static method to load a none (none, none) tuple mini-batch into state :param state: The current state dict of the :class:`Model`. :type state: dict[str,any] """ state[torchbearer.X], state[torchbearer.Y_TRUE] = None, None @staticmethod def _load_batch_predict(iterator, state): """ Static method to load a prediction (input data, target) or (input data) mini-batch from iterator into state :param iterator: Training or validation data iterator :type iterator: iterable :param state: The current state dict of the :class:`Model`. :type state: dict[str,any] """ data = Model._deep_to(next(state[iterator]), state[torchbearer.DEVICE], state[torchbearer.DATA_TYPE]) if isinstance(data, list) or isinstance(data, tuple): state[torchbearer.X], state[torchbearer.Y_TRUE] = data else: state[torchbearer.X] = data @staticmethod def _update_device_and_dtype_from_args(main_state, *args, **kwargs): """ static method to update a main state dictionary with new data type and device values :param main_state: The main state to update :type main_state: dict[str,any] :param args: Arguments to the :func:`Model.to` function :param kwargs: Keyword arguments to the :func:`Model.to` function :return: Updated main state dictionary :rtype: dict[str,any] """ for key, _ in kwargs.items(): if key == 'device': main_state[torchbearer.DATA_TYPE] = kwargs['dtype'] elif 'device' in kwargs: main_state[torchbearer.DEVICE] = kwargs['device'] for arg in args: if isinstance(arg, torch.dtype): main_state[torchbearer.DATA_TYPE] = arg else: main_state[torchbearer.DEVICE] = arg return main_state