Source code for torchbearer.callbacks.early_stopping

from __future__ import print_function

import warnings

import torchbearer

from torchbearer.callbacks import Callback
from .decorators import only_if
from torchbearer.bases import get_metric


class EarlyStopping(Callback):
    """Callback to stop training when a monitored quantity has stopped improving.

    Example: ::

        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import EarlyStopping

        # Example Trial which stops training if the validation accuracy fails to improve on the best value seen
        # so far for 5 epochs in a row
        >>> stopping = EarlyStopping(monitor='val_acc', patience=5, mode='max')
        >>> trial = Trial(None, callbacks=[stopping], metrics=['acc'])

    Args:
        monitor (str): Name of quantity in metrics to be monitored
        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
            change of less than min_delta will count as no improvement.
        patience (int): Number of checks (epochs, or training steps if ``step_on_batch`` is True) with no improvement
            after which training will be stopped. With the default of 0, training stops at the first check that
            shows no improvement.
        mode (str): One of {auto, min, max}. In `min` mode, training will stop when the quantity monitored has stopped
            decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode,
            the direction is inferred from the name of the monitored quantity (`max` if the name contains 'acc',
            otherwise `min`). Any other value emits a warning and falls back to `auto`.
        step_on_batch (bool): If True, check for improvement after every training step instead of at the end of
            every epoch.

    State Requirements:
        - :attr:`torchbearer.state.METRICS`: Metrics should be a dict containing the given monitor key as a minimum
    """

    def __init__(self, monitor='val_loss', min_delta=0, patience=0, mode='auto', step_on_batch=False):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.step_on_batch = step_on_batch

        if mode not in ('auto', 'min', 'max'):
            # Previously an unknown mode was silently treated as 'auto', which could pick the wrong direction.
            warnings.warn('EarlyStopping mode {} is unknown, falling back to auto mode.'.format(mode))
            mode = 'auto'
        if mode == 'auto':
            # Accuracy-like quantities should increase; anything else (e.g. a loss) should decrease.
            mode = 'max' if 'acc' in self.monitor else 'min'
        self.mode = mode

        if self.mode == 'min':
            # min_delta is stored negated so that step() can use ``current - self.min_delta`` for both modes:
            # in min mode this evaluates to ``current + min_delta < best``, in max mode ``current - min_delta > best``.
            self.min_delta *= -1
            self.monitor_op = lambda x1, x2: x1 < x2
        else:
            self.monitor_op = lambda x1, x2: x1 > x2

        self.wait = 0
        self.best = float('inf') if self.mode == 'min' else -float('inf')

    def state_dict(self):
        """Get the resumable state of this callback.

        Returns:
            dict: The current no-improvement counter (``'wait'``) and the best value seen so far (``'best'``)
        """
        state_dict = {
            'wait': self.wait,
            'best': self.best
        }
        return state_dict

    def load_state_dict(self, state_dict):
        """Resume this callback from a dict produced by :meth:`state_dict`.

        Args:
            state_dict (dict): Dict with ``'wait'`` and ``'best'`` keys

        Returns:
            EarlyStopping: self
        """
        self.wait = state_dict['wait']
        self.best = state_dict['best']
        return self

    def step(self, state):
        """Compare the monitored quantity against the best value seen and request a stop once patience runs out.

        Args:
            state (dict): The current trial state; ``state[torchbearer.STOP_TRAINING]`` is set to True to stop.
        """
        current = get_metric('Early Stopping', state, self.monitor)
        if current is None:
            # get_metric gave nothing to compare (presumably the monitor key is missing); skip this check.
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                state[torchbearer.STOP_TRAINING] = True

    @only_if(lambda self, _: self.step_on_batch)
    def on_step_training(self, state):
        # Per-batch checking, only active when step_on_batch is True.
        self.step(state)

    @only_if(lambda self, _: not self.step_on_batch)
    def on_end_epoch(self, state):
        # Per-epoch checking, the default when step_on_batch is False.
        self.step(state)