Source code for torchbearer.callbacks.terminate_on_nan

from __future__ import print_function
import torchbearer

from torchbearer.callbacks import Callback
from torchbearer.bases import get_metric

import math


class TerminateOnNaN(Callback):
    """Callback which monitors the given metric and requests that training stop if its value is NaN or infinite.

    Example: ::

        >>> import torch.nn
        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import TerminateOnNaN

        # Example Trial which terminates on a NaN, forced by a separate callback. Terminates on the 11th batch since
        # the running loss only updates every 10 iterations.
        >>> term = TerminateOnNaN(monitor='running_loss')
        >>> @torchbearer.callbacks.on_criterion
        ... def force_terminate(state):
        ...     if state[torchbearer.BATCH] == 5:
        ...         state[torchbearer.LOSS] = state[torchbearer.LOSS] * torch.Tensor([float('NaN')])
        >>> trial = Trial(None, callbacks=[term, force_terminate], metrics=['loss'], verbose=2).for_steps(30).run(1)
        Invalid running_loss, terminating

    Args:
        monitor (str): The name of the metric to monitor

    State Requirements:
        - :attr:`torchbearer.state.METRICS`: Metrics should be a dict containing at least the key `monitor`
    """

    def __init__(self, monitor='running_loss'):
        super(TerminateOnNaN, self).__init__()
        # Name of the metric that every check looks up via ``get_metric``.
        self._monitor = monitor

    def _check(self, state):
        """Look up the monitored metric and flag the state to stop if it is NaN or infinite.

        Nothing happens when ``get_metric`` returns ``None`` or when the value is finite.

        Args:
            state: The state object handed to ``get_metric``; on an invalid value its
                ``torchbearer.STOP_TRAINING`` entry is set to ``True``.
        """
        metric_value = get_metric('TerminateOnNaN', state, self._monitor)

        # No value available for this metric: nothing to validate.
        if metric_value is None:
            return

        # Finite values are fine; NaN is tested first, exactly as in the combined `or` form.
        if not math.isnan(metric_value) and not math.isinf(metric_value):
            return

        message = 'Invalid ' + self._monitor + ', terminating'
        print(message)
        state[torchbearer.STOP_TRAINING] = True

    def on_step_training(self, state):
        """Run the metric check for the ``on_step_training`` callback hook.

        Args:
            state: The state object forwarded to :meth:`_check`.
        """
        self._check(state)

    def on_end_epoch(self, state):
        """Run the metric check for the ``on_end_epoch`` callback hook.

        Args:
            state: The state object forwarded to :meth:`_check`.
        """
        self._check(state)

    def on_step_validation(self, state):
        """Run the metric check for the ``on_step_validation`` callback hook.

        Args:
            state: The state object forwarded to :meth:`_check`.
        """
        self._check(state)