Source code for torchbearer.callbacks.checkpointers

import torchbearer

import torch

from torchbearer.callbacks.callbacks import Callback
import os


class _Checkpointer(Callback):
    """Base class for callbacks which write model checkpoints to disk.

    :param fileformat: Path template for the checkpoint file. It is expanded with `str.format` using the entries of the state and of its metrics, keyed by the string form of each key
    :type fileformat: str
    :param pickle_module: The pickle module handed to `torch.save`, default is 'torch.serialization.pickle'
    :param pickle_protocol: The pickle protocol handed to `torch.save`, default is 'torch.serialization.DEFAULT_PROTOCOL'
    """
    def __init__(self, fileformat, pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL):
        super().__init__()
        self.fileformat = fileformat

        self.pickle_module = pickle_module
        self.pickle_protocol = pickle_protocol

        # Path of the last checkpoint written by this instance (None until the first save)
        self.most_recent = None

        # Create the target directory up front. `exist_ok=True` closes the race between the existence check and
        # the creation; testing the dirname (rather than `os.sep`) also covers alternative path separators.
        directory = os.path.dirname(fileformat)
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)

    def save_checkpoint(self, model_state, overwrite_most_recent=False):
        """Write the state dict of `model_state[torchbearer.SELF]` to the path produced from the file format.

        :param model_state: The current state; must contain `torchbearer.METRICS` (a mapping merged into the formatting values) and `torchbearer.SELF` (an object exposing `state_dict()`)
        :type model_state: dict
        :param overwrite_most_recent: If True, the checkpoint previously written by this instance is deleted once the new one has been saved
        :type overwrite_most_recent: bool
        """
        state = {}
        state.update(model_state)
        state.update(model_state[torchbearer.METRICS])

        # `str.format` needs string keyword names, so stringify every key
        string_state = {str(key): value for key, value in state.items()}
        filepath = self.fileformat.format(**string_state)

        previous = self.most_recent

        # Save BEFORE deleting the old file so a failed save never leaves us without any checkpoint
        torch.save(model_state[torchbearer.SELF].state_dict(), filepath, pickle_module=self.pickle_module,
                   pickle_protocol=self.pickle_protocol)

        self.most_recent = filepath

        if overwrite_most_recent and previous is not None and previous != filepath:
            try:
                # Guard against two different strings naming the same file: never delete what was just written
                if not os.path.samefile(previous, filepath):
                    os.remove(previous)
            except FileNotFoundError:
                # The old checkpoint is already gone (e.g. removed externally); nothing left to clean up
                pass


def ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', save_best_only=False, mode='auto', period=1, min_delta=0):
    """Build a checkpointing callback which saves the model after every epoch.

    `filepath` can contain named formatting options, which will be filled with values from state.
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}`, then the model checkpoints will be saved
    with the epoch number and the validation loss in the filename.

    A :class:`Best` checkpointer is returned when `save_best_only=True`, otherwise an :class:`Interval`
    checkpointer is returned (in which case `monitor`, `mode` and `min_delta` are not used).

    :param filepath: Path to save the model file
    :type filepath: str
    :param monitor: Quantity to monitor
    :type monitor: str
    :param save_best_only: If `save_best_only=True`, the latest best model according to the quantity monitored will not be overwritten
    :type save_best_only: bool
    :param mode: One of {auto, min, max}. If `save_best_only=True`, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For `val_acc`, this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity.
    :type mode: str
    :param period: Interval (number of epochs) between checkpoints
    :type period: int
    :param min_delta: If `save_best_only=True`, this is the minimum improvement required to trigger a save
    :type min_delta: float
    """
    if not save_best_only:
        # Plain periodic checkpointing: nothing is monitored
        return Interval(filepath, period)
    return Best(filepath, monitor, mode, period, min_delta)
class MostRecent(_Checkpointer):
    """Checkpointer which keeps only the most recently saved model at the given filepath.

    :param filepath: Path to save the model file
    :type filepath: str
    :param pickle_module: The pickle module to use, default is 'torch.serialization.pickle'
    :param pickle_protocol: The pickle protocol to use, default is 'torch.serialization.DEFAULT_PROTOCOL'
    """

    def __init__(self, filepath='model.{epoch:02d}-{val_loss:.2f}.pt', pickle_module=torch.serialization.pickle,
                 pickle_protocol=torch.serialization.DEFAULT_PROTOCOL):
        pickle_options = {'pickle_module': pickle_module, 'pickle_protocol': pickle_protocol}
        super().__init__(filepath, **pickle_options)
        self.filepath = filepath

    def on_checkpoint(self, state):
        """Save a checkpoint, replacing the one written by the previous call.

        :param state: The current state, passed through to `save_checkpoint`
        """
        # NOTE(review): this forwards to the parent's `on_end_epoch` hook, not `on_checkpoint` - confirm intended
        super().on_end_epoch(state)
        self.save_checkpoint(state, overwrite_most_recent=True)
class Best(_Checkpointer):
    """Checkpointer which saves the model only when the monitored quantity improves.

    :param filepath: Path to save the model file
    :type filepath: str
    :param monitor: Quantity to monitor
    :type monitor: str
    :param mode: One of {auto, min, max}. The decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For `val_acc`, this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity.
    :type mode: str
    :param period: Interval (number of epochs) between checkpoints
    :type period: int
    :param min_delta: This is the minimum improvement required to trigger a save
    :type min_delta: float
    :param pickle_module: The pickle module to use, default is 'torch.serialization.pickle'
    :param pickle_protocol: The pickle protocol to use, default is 'torch.serialization.DEFAULT_PROTOCOL'
    """

    def __init__(self, filepath='model.{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='auto', period=1,
                 min_delta=0, pickle_module=torch.serialization.pickle,
                 pickle_protocol=torch.serialization.DEFAULT_PROTOCOL):
        super().__init__(filepath, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
        self.monitor = monitor
        self.mode = mode
        self.min_delta = min_delta
        self.period = period
        self.epochs_since_last_save = 0

        # Anything other than an explicit 'min'/'max' is resolved from the monitored name:
        # names containing 'acc' are maximised, everything else is minimised
        if self.mode not in ('min', 'max'):
            self.mode = 'max' if 'acc' in self.monitor else 'min'

        if self.mode == 'min':
            # Flip the sign so that `x1 - min_delta` becomes `x1 + |min_delta|` for minimisation
            self.min_delta *= -1
            self.monitor_op = lambda x1, x2: (x1 - self.min_delta) < x2
        elif self.mode == 'max':
            # Numerically a no-op, kept so both branches treat `min_delta` alike
            self.min_delta *= 1
            self.monitor_op = lambda x1, x2: (x1 - self.min_delta) > x2

        # Best value seen so far; given its initial value in `on_start`
        self.best = None

    def state_dict(self):
        """Return the parent's state dict extended with the epoch counter ('epochs') and the best value ('best')."""
        checkpointer_state = super().state_dict()
        checkpointer_state['epochs'] = self.epochs_since_last_save
        checkpointer_state['best'] = self.best
        return checkpointer_state

    def load_state_dict(self, state_dict):
        """Restore the epoch counter and best value from `state_dict`.

        :param state_dict: A dict holding the 'epochs' and 'best' entries, also passed to the parent's `load_state_dict`
        :return: This callback
        """
        super().load_state_dict(state_dict)
        self.best = state_dict['best']
        self.epochs_since_last_save = state_dict['epochs']
        return self

    def on_start(self, state):
        """Initialise the best value to the worst possible one for the mode, unless a value is already set.

        :param state: The current state (unused)
        """
        if self.best is not None:
            return
        self.best = float('inf') if self.mode == 'min' else float('-inf')

    def on_checkpoint(self, state):
        """Every `period` calls, save a checkpoint if the monitored metric improved on the best value.

        :param state: The current state; the monitored value is read from `state[torchbearer.METRICS]`
        """
        # NOTE(review): this forwards to the parent's `on_end_epoch` hook, not `on_checkpoint` - confirm intended
        super().on_end_epoch(state)

        self.epochs_since_last_save += 1
        period_elapsed = self.epochs_since_last_save >= self.period
        if not period_elapsed:
            return

        self.epochs_since_last_save = 0
        current = state[torchbearer.METRICS][self.monitor]
        if not self.monitor_op(current, self.best):
            return

        self.best = current
        self.save_checkpoint(state, overwrite_most_recent=True)
class Interval(_Checkpointer):
    """Checkpointer which saves the model every 'period' epochs to the given filepath.

    :param filepath: Path to save the model file
    :type filepath: str
    :param period: Interval (number of epochs) between checkpoints
    :type period: int
    :param pickle_module: The pickle module to use, default is 'torch.serialization.pickle'
    :param pickle_protocol: The pickle protocol to use, default is 'torch.serialization.DEFAULT_PROTOCOL'
    """

    def __init__(self, filepath='model.{epoch:02d}-{val_loss:.2f}.pt', period=1,
                 pickle_module=torch.serialization.pickle, pickle_protocol=torch.serialization.DEFAULT_PROTOCOL):
        super().__init__(filepath, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
        self.epochs_since_last_save = 0
        self.period = period

    def state_dict(self):
        """Return the parent's state dict extended with the epoch counter under 'epochs'."""
        checkpointer_state = super().state_dict()
        checkpointer_state['epochs'] = self.epochs_since_last_save
        return checkpointer_state

    def load_state_dict(self, state_dict):
        """Restore the epoch counter from `state_dict`.

        :param state_dict: A dict holding the 'epochs' entry, also passed to the parent's `load_state_dict`
        :return: This callback
        """
        super().load_state_dict(state_dict)
        self.epochs_since_last_save = state_dict['epochs']
        return self

    def on_checkpoint(self, state):
        """Save a checkpoint on every `period`-th call, without removing earlier checkpoints.

        :param state: The current state, passed through to `save_checkpoint`
        """
        # NOTE(review): this forwards to the parent's `on_end_epoch` hook, not `on_checkpoint` - confirm intended
        super().on_end_epoch(state)

        self.epochs_since_last_save += 1
        period_elapsed = self.epochs_since_last_save >= self.period
        if not period_elapsed:
            return

        self.epochs_since_last_save = 0
        self.save_checkpoint(state)