"""Source code for ``torchbearer.variational.divergence``: callbacks that add divergence (e.g. KL) terms to the loss."""

import functools

import torch
import torchbearer
from torchbearer import cite
import torchbearer.callbacks as callbacks

# BibTeX entries passed to the ``cite`` decorator on the classes / methods below.

beta_vae = """
@inproceedings{higgins2017beta,
  title={beta-vae: Learning basic visual concepts with a constrained variational framework},
  author={Higgins, Irina and Matthey, Loic and Pal, Arka and Burgess, Christopher and Glorot, Xavier and Botvinick, Matthew and Mohamed, Shakir and Lerchner, Alexander},
  booktitle={International Conference on Learning Representations},
  year={2017}
}
"""

understanding_beta_vae = """
@article{burgess2018understanding,
  title={Understanding disentangling in beta-vae},
  author={Burgess, Christopher P and Higgins, Irina and Pal, Arka and Matthey, Loic and Watters, Nick and Desjardins, Guillaume and Lerchner, Alexander},
  journal={arXiv preprint arXiv:1804.03599},
  year={2018}
}
"""

weibullKL = """
@article{bauckhage2014computing,
  author    = {Christian Bauckhage},
  title     = {Computing the Kullback-Leibler Divergence between two Generalized
               Gamma Distributions},
  journal   = {CoRR},
  volume    = {abs/1401.6853},
  year      = {2014}
}
"""

class DivergenceBase(callbacks.Callback):
    """The :class:`DivergenceBase` class is an abstract base class which defines a series of useful methods for dealing
    with divergences. The keys dict given on init is used to map objects in state to kwargs in the compute function.

    On each criterion call (training and validation) the divergence returned by :meth:`compute` is reduced to a single
    value, optionally stored in state, passed through the registered post function(s) and added to
    ``state[torchbearer.LOSS]``.

    Args:
        keys (dict): Dictionary which maps kwarg names to :class:`.StateKey` objects. When :meth:`compute` is called,
            the given kwargs are mapped to their associated values in state.
        state_key: If not None, the reduced divergence (detached, before any post function is applied) is stored in
            state with the given key.
    """

    def __init__(self, keys, state_key=None):
        self.keys = keys
        self.state_key = state_key
        # Identity post function; with_post_function composes further operations on top of it.
        self._post = lambda loss: loss
        # Default reduction, identical to with_sum_mean_reduction.
        self._reduce = lambda x: x.sum(1).mean(0)

        def store(state, val):
            # Detach so the stored (metric) value does not keep the autograd graph alive.
            state[state_key] = val.detach()

        # Only write to state when a key was given; otherwise storing is a no-op.
        self._store = store if state_key is not None else (lambda state, val: None)

    def with_post_function(self, post_fcn):
        """Register the given post function, to be applied to the loss after reduction. Post functions compose: the new
        function is applied to the output of all previously registered post functions.

        Args:
            post_fcn: A function of loss which applies some operation (e.g. multiplying by beta)

        Returns:
            Divergence: self
        """
        old_post = self._post
        self._post = lambda loss: post_fcn(old_post(loss))
        return self

    def compute(self, **kwargs):
        """Compute the loss with the given kwargs defined in the constructor.

        Args:
            kwargs: The bound kwargs, taken from state with the keys given in the constructor

        Returns:
            The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
        """
        raise NotImplementedError

    def loss(self, state):
        """Gather the kwargs for :meth:`compute` from state (using ``self.keys``) and return the unreduced divergence.

        Args:
            state (dict): The torchbearer state

        Returns:
            The output of :meth:`compute`
        """
        kwargs = {name: state[key] for name, key in self.keys.items()}
        return self.compute(**kwargs)

    def _add_divergence(self, state):
        """Reduce the divergence, store it if requested and add its post-processed value to the loss in state."""
        div = self._reduce(self.loss(state))
        self._store(state, div)
        state[torchbearer.LOSS] = state[torchbearer.LOSS] + self._post(div)

    def on_criterion(self, state):
        """Add the divergence to the training loss."""
        self._add_divergence(state)

    def on_criterion_validation(self, state):
        """Add the divergence to the validation loss."""
        self._add_divergence(state)

    def with_reduction(self, reduction_fcn):
        """Override the reduction operation with the given function, use this if your divergence doesn't output a two
        dimensional tensor.

        Args:
            reduction_fcn: The function to be applied to the divergence output and return a single value

        Returns:
            Divergence: self
        """
        self._reduce = reduction_fcn
        return self

    def with_sum_mean_reduction(self):
        """Override the reduction function to take a sum over dimension one and a mean over dimension zero. (default)

        Returns:
            Divergence: self
        """
        return self.with_reduction(lambda x: x.sum(1).mean(0))

    def with_sum_sum_reduction(self):
        """Override the reduction function to take a sum over all dimensions.

        Returns:
            Divergence: self
        """
        return self.with_reduction(lambda x: x.sum())

    @cite(beta_vae)
    def with_beta(self, beta):
        """Multiply the divergence by the given beta, as introduced by beta-vae.

        Args:
            beta (float): The beta (> 1) to multiply by.

        Returns:
            Divergence: self
        """
        def beta_div(loss):
            return beta * loss
        return self.with_post_function(beta_div)

    @cite(understanding_beta_vae)
    def with_linear_capacity(self, min_c=0, max_c=25, steps=100000, gamma=1000):
        """Limit divergence by capacity, linearly increased from min_c to max_c for steps, as introduced in
        `Understanding disentangling in beta-VAE`. The post function applied is ``gamma * |loss - C|`` where the
        capacity ``C`` starts at ``min_c`` and grows by ``(max_c - min_c) / steps`` on every training step until it
        reaches ``max_c``.

        Args:
            min_c (float): Minimum capacity
            max_c (float): Maximum capacity
            steps (int): Number of steps to increase over (values below 1 are treated as 1)
            gamma (float): Multiplicative gamma, usually a high number

        Returns:
            Divergence: self
        """
        # Per-step increment: total capacity range spread over `steps`. float() avoids integer division on Python 2.
        inc = (max_c - min_c) / float(max(steps, 1))
        d = {'c': min_c}  # mutable cell shared by the two closures below
        old_callback = self.on_step_training

        @functools.wraps(old_callback)
        def step_c(state):
            if d['c'] < max_c:
                # Clamp so the capacity never overshoots max_c on the final increment.
                d['c'] = min(d['c'] + inc, max_c)
            return old_callback(state)

        self.on_step_training = step_c

        def limit_div(loss):
            return gamma * (loss - d['c']).abs()

        return self.with_post_function(limit_div)
class SimpleNormalUnitNormalKL(DivergenceBase):
    """A KL divergence between a SimpleNormal (or similar) instance and a fixed unit normal (N[0, 1]) target.

    .. note::

       The distribution object must have mu and logvar attributes

    Args:
        input_key: :class:`.StateKey` instance which will be mapped to the distribution object.
        state_key: If not None, the reduced divergence is stored in state with the given key.
    """

    def __init__(self, input_key, state_key=None):
        super(SimpleNormalUnitNormalKL, self).__init__({'input': input_key}, state_key=state_key)

    def compute(self, input):
        """Elementwise KL(input || N(0, 1)) = 0.5 * (mu^2 + exp(logvar) - logvar - 1).

        Args:
            input: Distribution object exposing ``mu`` and ``logvar`` tensors. (The name is bound to the ``'input'``
                key given in the constructor, so it cannot be renamed despite shadowing the builtin.)

        Returns:
            Tensor with the same shape as ``mu`` / ``logvar``
        """
        mu, logvar = input.mu, input.logvar
        return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
class SimpleNormalSimpleNormalKL(DivergenceBase):
    """A KL divergence between two SimpleNormal (or similar) distributions.

    .. note::

       The distribution objects must have mu and logvar attributes

    Args:
        input_key: :class:`.StateKey` instance which will be mapped to the input distribution object.
        target_key: :class:`.StateKey` instance which will be mapped to the target distribution object.
        state_key: If not None, the reduced divergence is stored in state with the given key.
    """

    def __init__(self, input_key, target_key, state_key=None):
        super(SimpleNormalSimpleNormalKL, self).__init__({'input': input_key, 'target': target_key},
                                                         state_key=state_key)

    def compute(self, input, target):
        """Elementwise KL(input || target) for diagonal normals, with var = exp(logvar):
        0.5 * (var_1 / var_2 + (mu_2 - mu_1)^2 / var_2 + logvar_2 - logvar_1 - 1).

        Args:
            input: Input distribution object exposing ``mu`` and ``logvar`` tensors
            target: Target distribution object exposing ``mu`` and ``logvar`` tensors

        Returns:
            Tensor of elementwise divergences (broadcast shape of the inputs)
        """
        mu_1, logvar_1 = input.mu, input.logvar
        mu_2, logvar_2 = target.mu, target.logvar
        # Ratios of variances are formed in log space so large logvars do not overflow to inf / inf.
        var_ratio = (logvar_1 - logvar_2).exp()
        inv_var_2 = (-logvar_2).exp()
        return 0.5 * (var_ratio + (mu_2 - mu_1).pow(2) * inv_var_2 + logvar_2 - logvar_1 - 1)
@cite(weibullKL)
class SimpleWeibullSimpleWeibullKL(DivergenceBase):
    """A KL divergence between two SimpleWeibull (or similar) distributions.

    .. note::

       The distribution object must have lambda (scale) and k (shape) attributes

    Args:
        input_key: :class:`.StateKey` instance which will be mapped to the input distribution object.
        target_key: :class:`.StateKey` instance which will be mapped to the target distribution object.
        state_key: If not None, the reduced divergence is stored in state with the given key.
    """

    def __init__(self, input_key, target_key, state_key=None):
        super(SimpleWeibullSimpleWeibullKL, self).__init__({'input': input_key, 'target': target_key},
                                                           state_key=state_key)
        # Euler-Mascheroni constant (previously truncated to 0.5772).
        self.gamma = 0.5772156649015329

    def compute(self, input, target):
        """Elementwise KL(input || target) between Weibull distributions (Bauckhage, 2014):
        log(k_1 / l_1^k_1) - log(k_2 / l_2^k_2) + (k_1 - k_2)(log l_1 - gamma / k_1)
        + (l_1 / l_2)^k_2 * Gamma(k_2 / k_1 + 1) - 1.

        The result is returned unreduced so that the configured reduction (sum over dim 1, mean over dim 0 by default)
        applies, as for the other divergences. Returning a scalar here made the default reduction fail.

        Args:
            input: Input distribution object exposing ``l`` (scale) and ``k`` (shape) tensors
            target: Target distribution object exposing ``l`` (scale) and ``k`` (shape) tensors

        Returns:
            Tensor of elementwise divergences (broadcast shape of the inputs)
        """
        lambda_1, k_1 = input.l, input.k
        lambda_2, k_2 = target.l, target.k
        log_lambda_1 = torch.log(lambda_1)
        log_lambda_2 = torch.log(lambda_2)
        # log(k / lambda^k) == log(k) - k * log(lambda); avoids materialising lambda^k, which can overflow.
        a = torch.log(k_1) - k_1 * log_lambda_1
        b = torch.log(k_2) - k_2 * log_lambda_2
        c = (k_1 - k_2) * (log_lambda_1 - self.gamma / k_1)
        # (lambda_1 / lambda_2)^k_2 * Gamma(k_2 / k_1 + 1), evaluated as a single exp of a log-space sum.
        d = torch.exp(k_2 * (log_lambda_1 - log_lambda_2) + torch.lgamma(k_2 / k_1 + 1))
        return a - b + c + d - 1
class SimpleExponentialSimpleExponentialKL(DivergenceBase):
    """A KL divergence between two SimpleExponential (or similar) distributions.

    .. note::

       The distribution object must have lograte attribute

    Args:
        input_key: :class:`.StateKey` instance which will be mapped to the input distribution object.
        target_key: :class:`.StateKey` instance which will be mapped to the target distribution object.
        state_key: If not None, the reduced divergence is stored in state with the given key.
    """

    def __init__(self, input_key, target_key, state_key=None):
        super(SimpleExponentialSimpleExponentialKL, self).__init__({'input': input_key, 'target': target_key},
                                                                   state_key=state_key)

    def compute(self, input, target):
        """Elementwise KL(input || target) between exponential distributions:
        log(rate_1) - log(rate_2) + rate_2 / rate_1 - 1.

        Args:
            input: Input distribution object exposing a ``lograte`` tensor
            target: Target distribution object exposing a ``lograte`` tensor

        Returns:
            Tensor of elementwise divergences (broadcast shape of the inputs)
        """
        lograte_1 = input.lograte
        lograte_2 = target.lograte
        # rate_2 / rate_1 as exp(lograte_2 - lograte_1): exp(l2) / exp(l1) overflows to inf / inf = nan for large rates.
        return lograte_1 - lograte_2 + (lograte_2 - lograte_1).exp() - 1