import torchbearer
from torchbearer import Callback
import torch
import torch.nn.functional as F
from torch.distributions import Beta
from torchbearer.bases import cite
bc = """
@inproceedings{tokozume2018between,
title={Between-class learning for image classification},
author={Tokozume, Yuji and Ushiku, Yoshitaka and Harada, Tatsuya},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={5486--5494},
year={2018}
}
"""
@cite(bc)
class BCPlus(Callback):
"""BC+ callback which mixes images by treating them as waveforms. For standard BC, see :class:`.Mixup`.
This callback can optionally convert labels to one hot before combining them according to the lambda parameters,
sampled from a beta distribution, use alpha=1 to replicate the paper. Use with :meth:`BCPlus.bc_loss` or set
`mixup_loss = True` and use :meth:`.Mixup.mixup_loss`.
.. note::
This callback first sets all images to have zero mean. Consider adding an offset (e.g. 0.5) back before
visualising.
Example: ::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import BCPlus
# Example Trial which does BCPlus regularisation
>>> bcplus = BCPlus(classes=10)
>>> trial = Trial(None, criterion=BCPlus.bc_loss, callbacks=[bcplus], metrics=['acc'])
Args:
mixup_loss (bool): If True, the lambda and targets will be stored for use with the mixup loss function.
alpha (float): The alpha value for the beta distribution.
classes (int): The number of classes for conversion to one hot.
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored and correctly normalised
- :attr:`torchbearer.state.Y_TRUE`: State should have the current data stored
"""
    def __init__(self, mixup_loss=False, alpha=1, classes=-1):
        super(BCPlus, self).__init__()
        self.mixup_loss = mixup_loss
        self.classes = classes
        # Symmetric Beta(alpha, alpha) distribution over the mixing ratio; alpha=1 gives Uniform(0, 1),
        # matching the ratio sampling used in the paper.
        self.dist = Beta(torch.tensor([float(alpha)]), torch.tensor([float(alpha)]))

    @staticmethod
    def bc_loss(state):
        """The KL divergence between the outputs of the model and the ratio labels. Model outputs should be
        un-normalised logits as this function performs a log_softmax.

        Args:
            state: The current :class:`Trial` state.
        """
        prediction, target = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]

        # Entropy of the soft target, summed over its non-zero entries only to avoid 0 * log(0).
        entropy = - (target[target.nonzero().split(1, dim=1)] * target[target.nonzero().split(1, dim=1)].log()).sum()
        # Cross entropy between the soft target and the model's log-probabilities.
        cross = - (target * F.log_softmax(prediction, dim=1)).sum()

        return (cross - entropy) / prediction.size(0)
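
    # Illustrative sketch, not part of the library API: for a strictly positive soft target, bc_loss reduces to
    # the batch-mean KL divergence between the ratio label and the model's softmax output. For example,
    #
    #     t = torch.tensor([[0.7, 0.3]])   # mixed (ratio) label
    #     z = torch.tensor([[2.0, 0.0]])   # un-normalised logits
    #     kl = (t * (t.log() - F.log_softmax(z, dim=1))).sum() / t.size(0)
    #
    # gives the same value as BCPlus.bc_loss({torchbearer.Y_PRED: z, torchbearer.Y_TRUE: t}).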

    def _to_one_hot(self, target):
        # Hard (index) labels arrive as a 1D tensor; scatter them into a one hot matrix.
        if target.dim() == 1:
            target = target.unsqueeze(1)
            one_hot = torch.zeros_like(target).repeat(1, self.classes)
            one_hot.scatter_(1, target, 1)
            return one_hot
        # Targets that are already dense (e.g. one hot or ratio labels) are passed through as floats.
        return target.float()

    def on_sample(self, state):
        super(BCPlus, self).on_sample(state)
        lam = self.dist.sample().to(state[torchbearer.DEVICE])

        permutation = torch.randperm(state[torchbearer.X].size(0))

        # Zero-mean each image and compute its standard deviation (g), treating the image as a waveform.
        batch1 = state[torchbearer.X]
        batch1 = batch1 - batch1.view(batch1.size(0), -1).mean(1, keepdim=True).view(*tuple([batch1.size(0)] + [1] * (batch1.dim() - 1)))
        g1 = batch1.view(batch1.size(0), -1).std(1, keepdim=True).view(*tuple([batch1.size(0)] + [1] * (batch1.dim() - 1)))

        # The second batch is a random permutation of the first, with its matching standard deviations.
        batch2 = batch1[permutation]
        g2 = g1[permutation]

        # Mixing coefficient from the paper, which accounts for the relative magnitude (g1 / g2) of the pair.
        p = 1. / (1 + ((g1 / g2) * ((1 - lam) / lam)))

        # Mix the waveforms and rescale so that the standard deviation of the mixture is roughly preserved.
        state[torchbearer.X] = (batch1 * p + batch2 * (1 - p)) / (p.pow(2) + (1 - p).pow(2)).sqrt()

        if not self.mixup_loss:
            # Ratio labels: a convex combination of the two one hot targets, weighted by lambda.
            target = self._to_one_hot(state[torchbearer.TARGET]).float()
            state[torchbearer.Y_TRUE] = lam * target + (1 - lam) * target[permutation]
        else:
            # Defer label mixing to the mixup loss: store lambda, the permutation and both sets of targets.
            state[torchbearer.MIXUP_LAMBDA] = lam
            state[torchbearer.MIXUP_PERMUTATION] = permutation
            state[torchbearer.Y_TRUE] = (state[torchbearer.Y_TRUE], state[torchbearer.Y_TRUE][state[torchbearer.MIXUP_PERMUTATION]])
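
    # Worked example (illustrative only): when the two images have equal per-image standard deviation
    # (g1 == g2), the coefficient p collapses to lam, so lam = 0.3 mixes the zero-mean images as
    # 0.3 * x1 + 0.7 * x2 before the division by sqrt(p**2 + (1 - p)**2), which keeps the standard
    # deviation of the mixture roughly constant, as in the waveform mixing of Tokozume et al.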

    def on_sample_validation(self, state):
        super(BCPlus, self).on_sample_validation(state)
        # No mixing is applied at validation time; only convert hard labels to one hot so that bc_loss still applies.
        if not self.mixup_loss:
            state[torchbearer.TARGET] = self._to_one_hot(state[torchbearer.TARGET]).float()
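

# Minimal, illustrative sketch; not part of torchbearer. It drives the callback by hand, using a plain dict in
# place of the Trial state (torchbearer state keys are ordinary dict keys) and a throwaway linear model, then
# evaluates bc_loss on the mixed batch. Shapes and the model are arbitrary choices for the demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, channels, size, num_classes = 8, 3, 32, 10
    x = torch.rand(batch_size, channels, size, size)
    y = torch.randint(num_classes, (batch_size,))

    bcplus = BCPlus(classes=num_classes)
    state = {torchbearer.X: x, torchbearer.TARGET: y, torchbearer.DEVICE: 'cpu'}
    bcplus.on_sample(state)  # mixes the images and produces soft (ratio) labels

    model = torch.nn.Linear(channels * size * size, num_classes)
    state[torchbearer.Y_PRED] = model(state[torchbearer.X].view(batch_size, -1))
    print(BCPlus.bc_loss(state).item())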