"""
The distributions module is an extension of the `torch.distributions` package, intended to facilitate the
implementations required for specific variational approaches through the :class:`.SimpleDistribution` class.
Generally, a :class:`torch.distributions.Distribution` object should be preferred over a :class:`.SimpleDistribution`,
since it offers better argument validation and more complete implementations. However, if you need to implement
something new for a specific variational approach, then a :class:`.SimpleDistribution` may be more forgiving.
Furthermore, the implementations here may be easier to understand.
"""
import math
from numbers import Number
import torch
from torch.distributions import Distribution
from torch.distributions.utils import broadcast_all
from torchbearer import cite
# BibTeX entry for Squires et al. (2019); passed to the ``cite`` decorator on SimpleWeibull below.
# The leading empty element produces the leading newline and the trailing one the final newline.
steve = "\n".join([
    "",
    "@article{squires2019a,",
    "title={A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation},",
    "author={Steven Squires and Adam Prugel-Bennett and Mahesan Niranjan},",
    "year={2019}",
    "}",
    "",
])
class SimpleDistribution(Distribution):
    """Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob
    function is not differentiable with respect to the distribution parameters or the given value, then this should be
    mentioned in the documentation.

    The rest of the :class:`torch.distributions.Distribution` interface is stubbed out: ``support``, ``mean`` and
    ``variance`` are ``None``, ``arg_constraints`` is an empty dict, and ``expand``, ``cdf``, ``icdf``,
    ``enumerate_support`` and ``entropy`` do nothing and return ``None``. Subclasses must override :meth:`rsample`
    and :meth:`log_prob`, which raise ``NotImplementedError`` here.

    Args:
        batch_shape (torch.Size): The shape over which the distribution parameters are batched
        event_shape (torch.Size): The shape of a single event
    """
    has_rsample = True

    def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size()):
        super(SimpleDistribution, self).__init__(batch_shape, event_shape)

    @property
    def support(self):
        return None

    @property
    def arg_constraints(self):
        """No parameter constraints are declared, so argument validation is a no-op.

        ``torch.distributions.Distribution`` calls ``self.arg_constraints.items()`` in its constructor whenever
        argument validation is enabled (the default in recent PyTorch releases) and in ``__repr__``. The value must
        therefore be a dict: returning ``None`` makes construction and ``repr`` raise ``AttributeError``.
        """
        return {}

    def expand(self, batch_shape, _instance=None):
        """Not supported; returns ``None``."""
        return None

    @property
    def mean(self):
        return None

    @property
    def variance(self):
        return None

    def cdf(self, value):
        """Not supported; returns ``None``."""
        return None

    def icdf(self, value):
        """Not supported; returns ``None``."""
        return None

    def enumerate_support(self, expand=True):
        """Not supported; returns ``None``."""
        return None

    def entropy(self):
        """Not supported; returns ``None``."""
        return None

    def rsample(self, sample_shape=torch.Size()):
        """Returns a reparameterized sample or batch of reparameterized samples if the distribution parameters are
        batched.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def log_prob(self, value):
        """Returns the log of the probability density/mass function evaluated at `value`.

        Args:
            value (torch.Tensor, Number): Value at which to evaluate log probability

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
class SimpleNormal(SimpleDistribution):
    """A :class:`SimpleDistribution` implementing a straightforward Normal / Gaussian distribution, parameterised by
    its mean and log variance. It performs significantly fewer checks than `torch.distributions.Normal`, but should be
    sufficient for the purpose of implementing a VAE.

    Args:
        mu (torch.Tensor, Number): The mean of the distribution, numbers will be cast to tensors
        logvar (torch.Tensor, Number): The log variance of the distribution, numbers will be cast to tensors
    """

    def __init__(self, mu, logvar):
        self.mu, self.logvar = broadcast_all(mu, logvar)
        # Two plain numbers describe a single, unbatched distribution; otherwise batch over the broadcast shape.
        scalar_params = isinstance(mu, Number) and isinstance(logvar, Number)
        super(SimpleNormal, self).__init__(batch_shape=torch.Size() if scalar_params else self.mu.size())

    def rsample(self, sample_shape=torch.Size()):
        """Simple rsample for a Normal distribution, using the reparameterization ``mu + sigma * eps`` with
        ``eps ~ N(0, 1)``.

        Args:
            sample_shape (torch.Size, tuple): Shape of the sample (per mean / variance given)

        Returns:
            A reparameterized sample with gradient with respect to the distribution parameters
        """
        out_shape = self._extended_shape(sample_shape)
        # sigma = exp(logvar / 2)
        sigma = self.logvar.div(2).exp_()
        zeros = torch.zeros(out_shape, dtype=self.mu.dtype, device=self.mu.device)
        noise = torch.normal(zeros, torch.ones_like(zeros))
        return self.mu + sigma * noise

    def log_prob(self, value):
        """Calculates the log probability that the given value was drawn from this distribution. Since the density of a
        Gaussian is differentiable, this function is differentiable.

        Args:
            value (torch.Tensor, Number): The sampled value

        Returns:
            The log probability that the given value was drawn from this distribution
        """
        variance = self.logvar.exp()
        squared_error = (value - self.mu) ** 2
        # log(sqrt(2 * pi)), the normalising constant of the Gaussian density
        log_norm = math.log(math.sqrt(2.0 * math.pi))
        return -squared_error / (2.0 * variance) - self.logvar / 2.0 - log_norm
class SimpleExponential(SimpleDistribution):
    """A :class:`SimpleDistribution` implementing a straightforward Exponential distribution with the given lograte.
    It performs significantly fewer checks than `torch.distributions.Exponential`, but should be sufficient for the
    purpose of implementing a VAE. Because the rate is given as a lograte, the log_prob can be computed in a stable
    fashion, without taking a logarithm.

    Args:
        lograte (torch.Tensor, Number): The natural log of the rate of the distribution, numbers will be cast to tensors
    """

    def __init__(self, lograte):
        (self.lograte,) = broadcast_all(lograte)
        if isinstance(lograte, Number):
            shape = torch.Size()
        else:
            shape = self.lograte.size()
        super(SimpleExponential, self).__init__(batch_shape=shape)

    def rsample(self, sample_shape=torch.Size()):
        """Simple rsample for an Exponential distribution: a unit-rate exponential draw divided by the rate.

        Args:
            sample_shape (torch.Size, tuple): Shape of the sample (per lograte given)

        Returns:
            A reparameterized sample with gradient with respect to the distribution parameters
        """
        out_shape = self._extended_shape(sample_shape)
        unit_draw = self.lograte.new_empty(out_shape).exponential_()
        rate = self.lograte.exp()
        return unit_draw / rate

    def log_prob(self, value):
        """Calculates the log probability that the given value was drawn from this distribution, i.e.
        ``lograte - exp(lograte) * value``. The log_prob for this distribution is fully differentiable and has stable
        gradient since we use the lograte here. No support check is performed, so negative values give a finite result.

        Args:
            value (torch.Tensor, Number): The sampled value

        Returns:
            The log probability that the given value was drawn from this distribution
        """
        rate = self.lograte.exp()
        return self.lograte - rate * value
@cite(steve)
class SimpleWeibull(SimpleDistribution):
    """The SimpleWeibull class is a :class:`SimpleDistribution` which implements a straightforward Weibull
    distribution. This performs significantly fewer checks than `torch.distributions.Weibull`, but should be sufficient
    for the purpose of implementing a VAE.

    Args:
        l (torch.Tensor, Number): The scale parameter of the distribution, numbers will be cast to tensors
        k (torch.Tensor, Number): The shape parameter of the distribution, numbers will be cast to tensors
    """

    def __init__(self, l, k):
        self.l, self.k = broadcast_all(l, k)
        # Small offset added inside the logarithm in log_prob so that log(0) is never evaluated at value == 0.
        self.const = 1e-8
        if isinstance(k, Number) and isinstance(l, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.k.size()
        super(SimpleWeibull, self).__init__(batch_shape=batch_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Simple rsample for a Weibull distribution via the inverse CDF: ``l * (-log(eps)) ** (1 / k)`` with
        ``eps ~ Uniform``.

        Args:
            sample_shape (torch.Size, tuple): Shape of the sample (per k / lambda given)

        Returns:
            A reparameterized sample with gradient with respect to the distribution parameters
        """
        shape = self._extended_shape(sample_shape)
        eps = torch.rand(shape, dtype=self.k.dtype, device=self.k.device)
        # torch.rand samples from [0, 1). An exact 0 would give -log(0) = inf and an infinite sample (with NaN
        # gradients), so clamp to the smallest positive normal number of the dtype.
        eps = eps.clamp(min=torch.finfo(eps.dtype).tiny)
        return self.l * torch.pow(-torch.log(eps), 1 / self.k)

    def log_prob(self, value):
        """Calculates the log probability that the given value was drawn from this distribution. This function is
        differentiable, and its log probability is -inf for values less than 0. A small constant (``self.const``) is
        added inside the logarithm of the value, so the result is a close approximation very near 0.

        Args:
            value (torch.Tensor, Number): The sampled value. Numbers are cast to the parameters' dtype and device

        Returns:
            The log probability that the given value was drawn from this distribution
        """
        if not torch.is_tensor(value):
            value = torch.tensor(value, dtype=self.k.dtype, device=self.k.device)
        # 1 inside the support (value >= 0) and 0 outside; log(lb) is then 0 or -inf, masking out negative values.
        lb = value.ge(0).to(self.k.dtype)
        # Negative values are replaced by 0 before the log and power terms. Otherwise the power term yields NaN
        # (a negative base raised to a non-integer k), which would swamp the -inf mask above.
        safe_value = value.clamp(min=0)
        log_norm = torch.log(self.k / self.l)
        log_kernel = (self.k - 1) * torch.log((safe_value + self.const) / self.l)
        return torch.log(lb) + log_norm + log_kernel - torch.pow(safe_value / self.l, self.k)