Source code for torchbearer.callbacks.mixup

import torch
import torch.nn.functional as F
from torch.distributions.beta import Beta

import torchbearer
from torchbearer import cite
from torchbearer.callbacks import Callback
from torchbearer.metrics import CategoricalAccuracy, AdvancedMetric, running_mean, mean, super


mixup= """
@inproceedings{zhang2018mixup,
  title={mixup: Beyond Empirical Risk Minimization},
  author={Hongyi Zhang and Moustapha Cisse and Yann N. Dauphin and David Lopez-Paz},
  booktitle={International Conference on Learning Representations},
  year={2018}
}
"""


@running_mean
@mean
class MixupAcc(AdvancedMetric):
    def __init__(self):
        super(MixupAcc, self).__init__('mixup_acc')
        self.cat_acc = CategoricalAccuracy().root

    def process_train(self, *args):
        super(MixupAcc, self).process_train(*args)
        state = args[0]
        
        target1, target2 = state[torchbearer.Y_TRUE]
        _state = args[0].copy()
        _state[torchbearer.Y_TRUE] = target1
        acc1 = self.cat_acc.process(_state)

        _state = args[0].copy()
        _state[torchbearer.Y_TRUE] = target2
        acc2 = self.cat_acc.process(_state)

        return acc1 * state[torchbearer.MIXUP_LAMBDA] + acc2 * (1-state[torchbearer.MIXUP_LAMBDA])

    def process_validate(self, *args):
        super(MixupAcc, self).process_validate(*args)

        return self.cat_acc.process(*args)

    def reset(self, state):
        self.cat_acc.reset(state)


[docs] @cite(mixup) class Mixup(Callback): """Perform mixup on the model inputs. Requires use of :meth:`MixupInputs.loss`, otherwise lambdas can be found in state under :attr:`.MIXUP_LAMBDA`. Model targets will be a tuple containing the original target and permuted target. .. note:: The accuracy metric for mixup is different on training to deal with the different targets, but for validation it is exactly the categorical accuracy, despite being called "val_mixup_acc" Example: :: >>> from torchbearer import Trial >>> from torchbearer.callbacks import Mixup # Example Trial which does Mixup regularisation >>> mixup = Mixup(0.9) >>> trial = Trial(None, criterion=Mixup.mixup_loss, callbacks=[mixup], metrics=['acc']) Args: lam (float): Mixup inputs by fraction lam. If RANDOM, choose lambda from Beta(alpha, alpha). Else, lambda=lam alpha (float): The alpha value to use in the beta distribution. """ RANDOM = -10.0 def __init__(self, alpha=1.0, lam=RANDOM): super(Mixup, self).__init__() self.alpha = alpha self.lam = lam if alpha > 0: self.distrib = Beta(self.alpha, self.alpha)
[docs] @staticmethod def mixup_loss(state): """The standard cross entropy loss formulated for mixup (weighted combination of `F.cross_entropy`). Args: state: The current :class:`Trial` state. """ input, target = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE] if state[torchbearer.DATA] is torchbearer.TRAIN_DATA: y1, y2 = target return F.cross_entropy(input, y1) * state[torchbearer.MIXUP_LAMBDA] + F.cross_entropy(input, y2) * (1-state[torchbearer.MIXUP_LAMBDA]) else: return F.cross_entropy(input, target)
[docs] def on_sample(self, state): if self.lam is Mixup.RANDOM: if self.alpha > 0: lam = self.distrib.sample() else: lam = 1.0 else: lam = self.lam state[torchbearer.MIXUP_LAMBDA] = lam state[torchbearer.MIXUP_PERMUTATION] = torch.randperm(state[torchbearer.X].size(0)) state[torchbearer.X] = state[torchbearer.X] * state[torchbearer.MIXUP_LAMBDA] + state[torchbearer.X][state[torchbearer.MIXUP_PERMUTATION], :] * (1-state[torchbearer.MIXUP_LAMBDA]) state[torchbearer.Y_TRUE] = (state[torchbearer.Y_TRUE], state[torchbearer.Y_TRUE][state[torchbearer.MIXUP_PERMUTATION]])
from torchbearer.metrics import default as d d.__loss_map__[Mixup.mixup_loss.__name__] = MixupAcc