from __future__ import print_function

import warnings

import torchbearer
from torchbearer.callbacks import Callback
class EarlyStopping(Callback):
    """Callback to stop training when a monitored quantity has stopped improving.

    Args:
        monitor (str): Name of quantity in metrics to be monitored
        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
            change of less than min_delta will count as no improvement. Only the magnitude is used; the direction
            of the comparison is taken from ``mode``.
        patience (int): Number of epochs with no improvement after which training will be stopped.
        verbose (int): Verbosity mode, will print stopping info if verbose > 0
        mode (str): One of {auto, min, max}. In `min` mode, training will stop when the quantity monitored has stopped
            decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode,
            the direction is inferred from the name of the monitored quantity (`max` if the name contains 'acc',
            `min` otherwise). Any other value is treated as `auto` and triggers a ``RuntimeWarning``.

    State Requirements:
        - :attr:`torchbearer.state.METRICS`: Metrics should be a dict containing the given monitor key as a minimum
    """

    def __init__(self, monitor='val_loss',
                 min_delta=0, patience=0, verbose=0, mode='auto'):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        # Keep only the magnitude. A negative argument would otherwise invert the threshold applied in
        # on_end_epoch and let slightly *worse* values count as improvements.
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.verbose = verbose
        self.mode = mode

        if self.mode not in ['min', 'max']:
            if self.mode != 'auto':
                warnings.warn('EarlyStopping mode %s is unknown, falling back to auto mode.' % (mode,),
                              RuntimeWarning)
            # auto: accuracy-like metrics should increase, anything else (e.g. a loss) should decrease.
            self.mode = 'max' if 'acc' in self.monitor else 'min'

        # From here on self.mode is guaranteed to be either 'min' or 'max'.
        if self.mode == 'min':
            # on_end_epoch tests `current - min_delta < best`; storing min_delta negated makes this
            # `current < best - |min_delta|`, i.e. the value must drop by at least |min_delta|.
            self.min_delta *= -1
            self.monitor_op = lambda x1, x2: x1 < x2
        else:
            # `current - |min_delta| > best`: the value must rise by at least |min_delta|.
            self.monitor_op = lambda x1, x2: x1 > x2

        self.wait = 0  # consecutive epochs without improvement
        self.stopped_epoch = 0  # stays 0 until early stopping triggers
        # Start from the worst possible value so the first finite metric always counts as an improvement.
        self.best = float('inf') if self.mode == 'min' else -float('inf')

    def state_dict(self):
        """Return the mutable progress of this callback so it can be checkpointed.

        Returns:
            dict: The ``wait``, ``best`` and ``stopped_epoch`` values
        """
        state_dict = {
            'wait': self.wait,
            'best': self.best,
            'stopped_epoch': self.stopped_epoch
        }
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore the progress previously produced by :meth:`state_dict`.

        Args:
            state_dict (dict): Must contain the keys ``wait``, ``best`` and ``stopped_epoch``
        """
        self.wait = state_dict['wait']
        self.best = state_dict['best']
        self.stopped_epoch = state_dict['stopped_epoch']

    def on_end_epoch(self, state):
        """Compare the monitored metric with the best value so far and request a stop once patience runs out.

        Args:
            state: The torchbearer state. Reads ``state[torchbearer.METRICS][monitor]`` and
                ``state[torchbearer.EPOCH]``; sets ``state[torchbearer.STOP_TRAINING] = True`` to stop.
        """
        current = state[torchbearer.METRICS][self.monitor]
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            # With patience=0 the first epoch without improvement stops training.
            if self.wait >= self.patience:
                self.stopped_epoch = state[torchbearer.EPOCH]
                state[torchbearer.STOP_TRAINING] = True

    def on_end(self, state):
        """Print the epoch at which training was stopped, if it was stopped early and verbose > 0.

        Args:
            state: The torchbearer state (unused)
        """
        # NOTE(review): stopped_epoch doubles as the "stopped" flag, so a stop recorded at epoch 0 is not
        # reported. The +1 presumably converts a 0-indexed EPOCH to a 1-based display; confirm.
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))