import torchbearer
from torchbearer import Callback
import torch
from torch.distributions import Beta
from torchbearer.bases import cite
cutout = """
@article{devries2017improved,
title={Improved regularization of convolutional neural networks with Cutout},
author={DeVries, Terrance and Taylor, Graham W},
journal={arXiv preprint arXiv:1708.04552},
year={2017}
}
"""
random_erase = """
@article{zhong2017random,
title={Random erasing data augmentation},
author={Zhong, Zhun and Zheng, Liang and Kang, Guoliang and Li, Shaozi and Yang, Yi},
journal={arXiv preprint arXiv:1708.04896},
year={2017}
}
"""
cutmix = """
@article{yun2019cutmix,
title={Cutmix: Regularization strategy to train strong classifiers with localizable features},
author={Yun, Sangdoo and Han, Dongyoon and Oh, Seong Joon and Chun, Sanghyuk and Choe, Junsuk and Yoo, Youngjoon},
journal={arXiv preprint arXiv:1905.04899},
year={2019}
}
"""
@cite(cutout)
class Cutout(Callback):
""" Cutout callback which randomly masks out patches of image data. Implementation a modified version of the code
found `here <https://github.com/uoguelph-mlrg/Cutout>`_.
Example: ::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import Cutout
# Example Trial which does Cutout regularisation
>>> cutout = Cutout(1, 10)
>>> trial = Trial(None, callbacks=[cutout], metrics=['acc'])
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
constant (float): Constant value for each square patch
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored
"""
    def __init__(self, n_holes, length, constant=0.):
        super(Cutout, self).__init__()
        self.constant = constant
        self.cutter = BatchCutout(n_holes, length, length)

    def on_sample(self, state):
        super(Cutout, self).on_sample(state)
        mask = self.cutter(state[torchbearer.X])
        erase_locations = mask == 0
        constant = torch.ones_like(state[torchbearer.X]) * self.constant
        state[torchbearer.X][erase_locations] = constant[erase_locations]
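
# A minimal usage sketch (not part of the original module) of what ``Cutout.on_sample`` does to a
# batch. A plain dict stands in for the torchbearer state and the batch shapes are assumptions.
#
#     >>> images = torch.rand(4, 3, 32, 32)
#     >>> state = {torchbearer.X: images}
#     >>> Cutout(n_holes=1, length=8, constant=0.).on_sample(state)
#     >>> # each image now contains an 8 x 8 patch filled with the constant value
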
@cite(random_erase)
class RandomErase(Callback):
""" Random erase callback which replaces random patches of image data with random noise.
Implementation a modified version of the cutout code found
`here <https://github.com/uoguelph-mlrg/Cutout>`_.
Example: ::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import RandomErase
# Example Trial which does Cutout regularisation
>>> erase = RandomErase(1, 10)
>>> trial = Trial(None, callbacks=[erase], metrics=['acc'])
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored
"""
    def __init__(self, n_holes, length):
        super(RandomErase, self).__init__()
        self.cutter = BatchCutout(n_holes, length, length)

    def on_sample(self, state):
        super(RandomErase, self).on_sample(state)
        mask = self.cutter(state[torchbearer.X])
        erase_locations = mask == 0
        random = torch.rand_like(state[torchbearer.X])
        state[torchbearer.X][erase_locations] = random[erase_locations]
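
# The same kind of sketch for ``RandomErase`` (again a plain dict in place of the torchbearer
# state, shapes assumed): the erased patch is filled with uniform noise rather than a constant.
#
#     >>> state = {torchbearer.X: torch.rand(4, 3, 32, 32)}
#     >>> RandomErase(n_holes=1, length=8).on_sample(state)
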
@cite(cutmix)
class CutMix(Callback):
""" Cutmix callback which replaces a random patch of image data with the corresponding patch from another image.
This callback also converts labels to one hot before combining them according to the lambda parameters, sampled from
a beta distribution as is done in the paper.
Example: ::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import CutMix
# Example Trial which does CutMix regularisation
>>> cutmix = CutMix(1, classes=10)
>>> trial = Trial(None, callbacks=[cutmix], metrics=['acc'])
Args:
alpha (float): The alpha value for the beta distribution.
classes (int): The number of classes for conversion to one hot.
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored
- :attr:`torchbearer.state.Y_TRUE`: State should have the current data stored
"""
    def __init__(self, alpha, classes=-1, mixup_loss=False):
        super(CutMix, self).__init__()
        self.classes = classes
        self.dist = Beta(torch.tensor([float(alpha)]), torch.tensor([float(alpha)]))
        self.mixup_loss = mixup_loss

    def _to_one_hot(self, target):
        if target.dim() == 1:
            target = target.unsqueeze(1)
            one_hot = torch.zeros_like(target).repeat(1, self.classes)
            one_hot.scatter_(1, target, 1)
            return one_hot
        return target

    def on_sample(self, state):
        super(CutMix, self).on_sample(state)
        lam = self.dist.sample().to(state[torchbearer.DEVICE])

        length = (1 - lam).sqrt()
        cutter = BatchCutout(1, (length * state[torchbearer.X].size(-1)).round().item(),
                             (length * state[torchbearer.X].size(-2)).round().item())
        mask = cutter(state[torchbearer.X])
        erase_locations = mask == 0

        permutation = torch.randperm(state[torchbearer.X].size(0))

        if self.mixup_loss:
            state[torchbearer.MIXUP_PERMUTATION] = permutation
            state[torchbearer.MIXUP_LAMBDA] = lam

        state[torchbearer.X][erase_locations] = state[torchbearer.X][permutation][erase_locations]

        if self.mixup_loss:
            state[torchbearer.TARGET] = (state[torchbearer.TARGET], state[torchbearer.TARGET][state[torchbearer.MIXUP_PERMUTATION]])
        else:
            target = self._to_one_hot(state[torchbearer.TARGET]).float()
            state[torchbearer.TARGET] = lam * target + (1 - lam) * target[permutation]

    def on_sample_validation(self, state):
        super(CutMix, self).on_sample_validation(state)
        if not self.mixup_loss:
            state[torchbearer.TARGET] = self._to_one_hot(state[torchbearer.TARGET]).float()
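
# A rough sketch (not part of the original module) of CutMix with one hot targets. A plain dict
# stands in for the torchbearer state; the batch shapes and class count are assumptions.
#
#     >>> state = {torchbearer.X: torch.rand(4, 3, 32, 32),
#     ...          torchbearer.Y_TRUE: torch.randint(10, (4,)),
#     ...          torchbearer.DEVICE: 'cpu'}
#     >>> CutMix(alpha=1.0, classes=10).on_sample(state)
#     >>> state[torchbearer.Y_TRUE].shape  # targets are now soft one hot rows mixed by lambda
#     torch.Size([4, 10])
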
class BatchCutout(object):
    """Randomly mask out one or more patches from a batch of images.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        width (int): The width (in pixels) of each patch.
        height (int): The height (in pixels) of each patch.
    """
    def __init__(self, n_holes, width, height):
        self.n_holes = n_holes
        self.width = width
        self.height = height

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (B, C, H, W).

        Returns:
            Tensor: Mask of the same size as the input, with n_holes patches of dimension width x height set to zero.
        """
        b = img.size(0)
        c = img.size(1)
        h = img.size(-2)
        w = img.size(-1)

        mask = torch.ones((b, h, w), device=img.device)

        for n in range(self.n_holes):
            y = torch.randint(h, (b,)).long()
            x = torch.randint(w, (b,)).long()

            y1 = (y - self.height // 2).clamp(0, h).int()
            y2 = (y + self.height // 2).clamp(0, h).int()
            x1 = (x - self.width // 2).clamp(0, w).int()
            x2 = (x + self.width // 2).clamp(0, w).int()

            for batch in range(b):
                mask[batch, y1[batch]: y2[batch], x1[batch]: x2[batch]] = 0

        mask = mask.unsqueeze(1).repeat(1, c, 1, 1)

        return mask
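
if __name__ == "__main__":
    # A small, self-contained sketch (not part of the library) showing BatchCutout directly:
    # build a mask for an assumed batch of 2 RGB 32 x 32 images and report how much was cut.
    images = torch.rand(2, 3, 32, 32)
    mask = BatchCutout(n_holes=1, width=8, height=8)(images)
    print(mask.shape)                          # torch.Size([2, 3, 32, 32])
    print((mask == 0).float().mean().item())   # fraction of pixels inside the cut patches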