Source code for torchbearer.callbacks.gradient_clipping

import torchbearer

from torchbearer.callbacks import Callback

import torch


class GradientNormClipping(Callback):
    """GradientNormClipping callback, which uses 'torch.nn.utils.clip_grad_norm_' to clip the gradient norms to the
    given value. If params is None they will be retrieved from state.

    Example: ::

        >>> import torch.nn
        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import GradientNormClipping

        # Example Trial which clips all model gradients norms at 2 under the L1 norm.
        >>> model = torch.nn.Linear(1,1)
        >>> clip = GradientNormClipping(2, 1)
        >>> trial = Trial(model, callbacks=[clip], metrics=['acc'])

    Args:
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
        params (Iterable[Tensor] or Tensor, optional): an iterable of Tensors or a single Tensor that will have
            gradients normalized, otherwise this is retrieved from state

    State Requirements:
        - :attr:`torchbearer.state.MODEL`: Model should have the `parameters` method
    """

    def __init__(self, max_norm, norm_type=2, params=None):
        super(GradientNormClipping, self).__init__()
        self.max_norm = max_norm
        self.norm_type = norm_type
        if params is not None:
            # Materialise once: a one-shot iterable such as ``model.parameters()`` would be exhausted by the first
            # clip call, silently disabling clipping for every later batch. A lone Tensor is wrapped rather than
            # iterated, since iterating a Tensor yields non-leaf slices whose ``.grad`` is None.
            params = [params] if isinstance(params, torch.Tensor) else list(params)
        self.params = params

    def on_start(self, state):
        """If params is None then retrieve from the model.

        Args:
            state (dict): The :class:`.Trial` state
        """
        if self.params is None:
            # Store a list, not a lazy ``filter`` object: a filter is consumed by the first ``clip_grad_norm_`` call
            # and every subsequent call would then see no parameters at all.
            self.params = [p for p in state[torchbearer.MODEL].parameters() if p.requires_grad]

    def on_backward(self, state):
        """Between the backward pass (which computes the gradients) and the step call (which updates the parameters),
        clip the gradient.

        Args:
            state (dict): The :class:`.Trial` state
        """
        torch.nn.utils.clip_grad_norm_(self.params, self.max_norm, norm_type=self.norm_type)
class GradientClipping(Callback):
    """GradientClipping callback, which uses 'torch.nn.utils.clip_grad_value_' to clip the gradients of the given
    parameters to the given value. If params is None they will be retrieved from state.

    Example: ::

        >>> import torch.nn
        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import GradientClipping

        # Example Trial which clips all model gradients to the range [-2, 2].
        >>> model = torch.nn.Linear(1,1)
        >>> clip = GradientClipping(2)
        >>> trial = Trial(model, callbacks=[clip], metrics=['acc'])

    Args:
        clip_value (float or int): maximum allowed value of the gradients
            The gradients are clipped in the range [-clip_value, clip_value]
        params (Iterable[Tensor] or Tensor, optional): an iterable of Tensors or a single Tensor that will have
            gradients clipped, otherwise this is retrieved from state

    State Requirements:
        - :attr:`torchbearer.state.MODEL`: Model should have the `parameters` method
    """

    def __init__(self, clip_value, params=None):
        super(GradientClipping, self).__init__()
        self.clip_value = clip_value
        if params is not None:
            # A lone Tensor must be wrapped, not passed to ``list``: iterating a Tensor splits it along dim 0 into
            # non-leaf views whose ``.grad`` is None, so nothing would ever be clipped. Other iterables are
            # materialised so that one-shot generators survive repeated ``on_backward`` calls.
            params = [params] if isinstance(params, torch.Tensor) else list(params)
        self.params = params

    def on_start(self, state):
        """If params is None then retrieve from the model.

        Args:
            state (dict): The :class:`.Trial` state
        """
        if self.params is None:
            # Cache a concrete list so the same trainable parameters are clipped on every backward pass.
            self.params = [p for p in state[torchbearer.MODEL].parameters() if p.requires_grad]

    def on_backward(self, state):
        """Between the backward pass (which computes the gradients) and the step call (which updates the parameters),
        clip the gradient.

        Args:
            state (dict): The :class:`.Trial` state
        """
        torch.nn.utils.clip_grad_value_(self.params, self.clip_value)