torchbearer.variational

Distributions

The distributions module is an extension of the torch.distributions package intended to facilitate implementations required for specific variational approaches through the SimpleDistribution class. Generally, a torch.distributions.Distribution object should be preferred over a SimpleDistribution, since it offers better argument validation and more complete implementations. However, if you need to implement something new for a specific variational approach, a SimpleDistribution may be more forgiving. The implementations here may also be easier to follow than their torch.distributions counterparts.

class torchbearer.variational.distributions.SimpleDistribution(batch_shape=torch.Size(), event_shape=torch.Size())[source]

Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob function is not differentiable with respect to the distribution parameters or the given value, then this should be mentioned in the documentation.

arg_constraints
cdf(value)[source]
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]

Returns the log of the probability density/mass function evaluated at value.

Parameters:value (torch.Tensor, Number) – Value at which to evaluate the log probability

mean
rsample(sample_shape=torch.Size())[source]

Returns a reparameterized sample or batch of reparameterized samples if the distribution parameters are batched.

support
variance
class torchbearer.variational.distributions.SimpleExponential(lograte)[source]

The SimpleExponential class is a SimpleDistribution which implements a straightforward Exponential distribution with the given lograte. This performs significantly fewer checks than torch.distributions.Exponential, but should be sufficient for the purpose of implementing a VAE. By using a lograte, the log_prob can be computed in a stable fashion, without taking a logarithm.

Parameters:lograte (torch.Tensor, Number) – The natural log of the rate of the distribution, numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. The log_prob for this distribution is fully differentiable and has a stable gradient, since it is computed directly from the lograte.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
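As a worked check of why the lograte keeps this stable: an exponential with rate λ = exp(lograte) has density λe^(-λx) for x ≥ 0, so its log is lograte - exp(lograte)·x, with no logarithm of the rate anywhere. The minimal sketch below evaluates that expression for plain Python floats; it illustrates the formula under that assumption and is not torchbearer's tensor implementation (it also does not mask negative values).

```python
import math


def exponential_log_prob(value, lograte):
    """Log-density of Exponential(rate=exp(lograte)) at value >= 0.

    Plain-float sketch, not torchbearer's implementation:
    log p(x) = log(rate) - rate * x = lograte - exp(lograte) * x,
    so the logarithm of the rate is never taken explicitly.
    """
    return lograte - math.exp(lograte) * value


# With lograte = 0 (rate 1), log p(x) = -x.
print(exponential_log_prob(2.0, 0.0))  # -2.0
```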
rsample(sample_shape=torch.Size())[source]

Simple rsample for an Exponential distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per lograte given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
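One common way to draw a reparameterized exponential sample is inverse-transform sampling: with noise u ~ Uniform(0, 1], x = -log(u) / exp(lograte). The sketch below assumes this scheme (torchbearer's rsample may differ in detail) and exposes the noise as an optional argument so the result can be checked deterministically.

```python
import math
import random


def exponential_rsample(lograte, u=None):
    """Reparameterized Exponential sample via the inverse CDF (assumed scheme).

    All randomness lives in u ~ Uniform(0, 1]; the sample
    x = -log(u) / rate = -log(u) * exp(-lograte) is a smooth function of
    lograte, which is what lets gradients reach the parameter.
    """
    if u is None:
        # 1 - random.random() lies in (0, 1], which avoids log(0).
        u = 1.0 - random.random()
    return -math.log(u) * math.exp(-lograte)


# Fixed noise u = e^-1 with rate 1 gives a sample of approximately 1.0.
print(exponential_rsample(0.0, u=math.exp(-1.0)))
```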
class torchbearer.variational.distributions.SimpleNormal(mu, logvar)[source]

The SimpleNormal class is a SimpleDistribution which implements a straightforward Normal / Gaussian distribution. This performs significantly fewer checks than torch.distributions.Normal, but should be sufficient for the purpose of implementing a VAE.

Parameters:
  • mu (torch.Tensor, Number) – The mean of the distribution, numbers will be cast to tensors
  • logvar (torch.Tensor, Number) – The log variance of the distribution, numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. Since the density of a Gaussian is differentiable, this function is differentiable.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
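For reference, the Gaussian log-density written in terms of mu and logvar is -0.5·(log 2π + logvar + (x - mu)² / exp(logvar)). Every term is smooth in value, mu and logvar, which is where the differentiability comes from. The scalar sketch below assumes this standard form; it is not the library code.

```python
import math


def normal_log_prob(value, mu, logvar):
    """Log-density of N(mu, exp(logvar)) at value (plain-float sketch).

    log p(x) = -0.5 * (log(2*pi) + logvar + (x - mu)**2 / exp(logvar))
    """
    return -0.5 * (math.log(2 * math.pi) + logvar + (value - mu) ** 2 / math.exp(logvar))


# At the mean of a unit Gaussian this is -0.5 * log(2*pi), about -0.9189.
print(normal_log_prob(0.0, 0.0, 0.0))
```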
rsample(sample_shape=torch.Size())[source]

Simple rsample for a Normal distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per mean / variance given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
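The reparameterization trick for a Gaussian writes the sample as mu + exp(0.5·logvar)·eps with eps ~ N(0, 1), so the randomness sits in eps and the sample is differentiable in mu and logvar. A minimal scalar sketch, assuming this standard form rather than quoting torchbearer's code:

```python
import math
import random


def normal_rsample(mu, logvar, eps=None):
    """Reparameterized Gaussian sample (plain-float sketch).

    sample = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, 1).
    """
    if eps is None:
        eps = random.gauss(0.0, 1.0)
    # std = exp(logvar / 2); the sample is linear in mu and smooth in logvar.
    return mu + math.exp(0.5 * logvar) * eps


print(normal_rsample(1.0, math.log(4.0), eps=0.5))  # approximately 2.0, since the std is 2
```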
class torchbearer.variational.distributions.SimpleUniform(low, high)[source]

The SimpleUniform class is a SimpleDistribution which implements a straightforward Uniform distribution in the interval [low, high). This performs significantly fewer checks than torch.distributions.Uniform, but should be sufficient for the purpose of implementing a VAE.

Parameters:
  • low (torch.Tensor, Number) – The lower range of the distribution (inclusive), numbers will be cast to tensors
  • high (torch.Tensor, Number) – The upper range of the distribution (exclusive), numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. Since this distribution is uniform, the log probability is -log(high - low) for all values in the range [low, high) and -inf elsewhere. This function is therefore only piecewise differentiable.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
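A scalar sketch of the piecewise log-density described above, assuming the half-open interval [low, high) exactly as documented (plain floats, not torchbearer's tensor code):

```python
import math


def uniform_log_prob(value, low, high):
    """Log-density of Uniform[low, high): -log(high - low) inside, -inf outside."""
    if low <= value < high:
        return -math.log(high - low)
    return -math.inf


print(uniform_log_prob(0.5, 0.0, 2.0))  # -log(2), about -0.693
print(uniform_log_prob(2.0, 0.0, 2.0))  # -inf, because high is excluded
```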
rsample(sample_shape=torch.Size())[source]

Simple rsample for a Uniform distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per low / high given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
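A reparameterized uniform sample can be written as low + u·(high - low) with u ~ Uniform[0, 1), which is linear in both parameters. The sketch below assumes that form (torchbearer's exact construction is not shown here) and exposes the noise as an optional argument for testing:

```python
import random


def uniform_rsample(low, high, u=None):
    """Reparameterized Uniform[low, high) sample: low + u * (high - low)."""
    if u is None:
        u = random.random()  # u lies in [0, 1)
    return low + u * (high - low)


print(uniform_rsample(2.0, 6.0, u=0.25))  # 3.0
```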
class torchbearer.variational.distributions.SimpleWeibull(l, k)[source]

The SimpleWeibull class is a SimpleDistribution which implements a straightforward Weibull distribution. This performs significantly fewer checks than torch.distributions.Weibull, but should be sufficient for the purpose of implementing a VAE.

@article{squires2019a,
title={A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation},
author={Steven Squires and Adam Prugel-Bennett and Mahesan Niranjan},
year={2019}
}
Parameters:
  • l (torch.Tensor, Number) – The scale parameter of the distribution, numbers will be cast to tensors
  • k (torch.Tensor, Number) – The shape parameter of the distribution, numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. This function is differentiable and its log probability is -inf for values less than 0.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
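For reference, a Weibull with scale l and shape k has log-density log(k/l) + (k - 1)·log(x/l) - (x/l)^k for x > 0 and zero density for x < 0. The scalar sketch below assumes this standard parameterisation (l as scale and k as shape, matching the parameter list above) and treats x = 0 separately, since log(0) is undefined:

```python
import math


def weibull_log_prob(value, l, k):
    """Log-density of a Weibull with scale l and shape k (plain-float sketch).

    For x > 0: log p(x) = log(k / l) + (k - 1) * log(x / l) - (x / l) ** k.
    """
    if value < 0:
        return -math.inf  # no probability mass below zero
    if value == 0:
        # (x / l) ** k vanishes; (k - 1) * log(x / l) is 0 when k == 1 and
        # diverges otherwise (density 0 for k > 1, unbounded for k < 1).
        if k == 1:
            return math.log(k / l)
        return -math.inf if k > 1 else math.inf
    z = value / l
    return math.log(k / l) + (k - 1) * math.log(z) - z ** k


# k = 1 reduces to an exponential with rate 1 / l: log(0.5) - 1 at x = 2, l = 2.
print(weibull_log_prob(2.0, 2.0, 1.0))
```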
rsample(sample_shape=torch.Size())[source]

Simple rsample for a Weibull distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per l / k given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
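Since the Weibull CDF is 1 - exp(-(x/l)^k), inverse-transform sampling gives x = l·(-log u)^(1/k) for u ~ Uniform(0, 1]. Assuming torchbearer's rsample follows this usual construction (the text here does not confirm it), a scalar sketch is:

```python
import math
import random


def weibull_rsample(l, k, u=None):
    """Reparameterized Weibull sample via the inverse CDF (assumed scheme).

    F(x) = 1 - exp(-(x / l) ** k), so with u ~ Uniform(0, 1] standing in for
    1 - F(x) we get x = l * (-log(u)) ** (1 / k), smooth in both l and k.
    """
    if u is None:
        u = 1.0 - random.random()  # u lies in (0, 1], which avoids log(0)
    return l * (-math.log(u)) ** (1.0 / k)


print(weibull_rsample(3.0, 2.0, u=math.exp(-1.0)))  # approximately 3.0
```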

Divergences

class torchbearer.variational.divergence.DivergenceBase(keys, state_key=None)[source]

The DivergenceBase class is an abstract base class which defines a series of useful methods for dealing with divergences. The keys dict given on init is used to map objects in state to kwargs in the compute function.

Parameters:
  • keys (dict) – Dictionary which maps kwarg names to StateKey objects. When compute() is called, the given kwargs are mapped to their associated values in state.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
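To illustrate the key mapping, the hypothetical helper below (call_with_state is not part of torchbearer) looks up each kwarg's value in state and calls compute with them. Plain strings stand in for StateKey objects, which is an assumption made for brevity:

```python
def call_with_state(compute, keys, state):
    """Hypothetical helper (not torchbearer API) mirroring the documented mapping.

    keys maps each kwarg name of compute to the key under which its value is
    stored in state.
    """
    kwargs = {name: state[key] for name, key in keys.items()}
    return compute(**kwargs)


# Plain strings stand in for StateKey objects here (an assumption for brevity).
state = {'latent_dist': 'q(z|x)', 'prior_dist': 'p(z)'}
keys = {'input': 'latent_dist', 'target': 'prior_dist'}
print(call_with_state(lambda input, target: (input, target), keys, state))
# ('q(z|x)', 'p(z)')
```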
compute(**kwargs)[source]

Compute the loss with the given kwargs defined in the constructor.

Parameters:kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns:The calculated divergence as a two-dimensional tensor (batch, distribution dimensions)
loss(state)[source]
on_criterion(state)[source]

Perform some action with the given state as context after the criterion has been evaluated.

Parameters:state (dict) – The current state dict of the Trial.
on_criterion_validation(state)[source]

Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.

Parameters:state (dict) – The current state dict of the Trial.
with_beta(beta)[source]

Multiply the divergence by the given beta, as introduced by beta-vae.

@article{higgins2016beta,
  title={beta-vae: Learning basic visual concepts with a constrained variational framework},
  author={Higgins, Irina and Matthey, Loic and Pal, Arka and Burgess, Christopher and Glorot, Xavier and Botvinick, Matthew and Mohamed, Shakir and Lerchner, Alexander},
  year={2016}
}
Parameters:beta (float) – The beta (> 1) to multiply by.
Returns:self
Return type:Divergence
with_linear_capacity(min_c=0, max_c=25, steps=100000, gamma=1000)[source]

Limit the divergence by a capacity which is linearly increased from min_c to max_c over the given number of steps, as introduced in Understanding disentangling in beta-VAE.

@article{burgess2018understanding,
  title={Understanding disentangling in beta-vae},
  author={Burgess, Christopher P and Higgins, Irina and Pal, Arka and Matthey, Loic and Watters, Nick and Desjardins, Guillaume and Lerchner, Alexander},
  journal={arXiv preprint arXiv:1804.03599},
  year={2018}
}
Parameters:
  • min_c (float) – Minimum capacity
  • max_c (float) – Maximum capacity
  • steps (int) – Number of steps to increase over
  • gamma (float) – Multiplicative gamma, usually a high number
Returns:

self

Return type:

Divergence
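Following the cited paper, the capacity-limited objective is gamma·|KL - C|, where the capacity C grows linearly from min_c to max_c. The sketch below computes C for a given step count and applies that objective to an already-reduced divergence. How torchbearer tracks the step count internally is not described here, so the explicit step argument is an assumption:

```python
def capacity_at(step, min_c=0.0, max_c=25.0, steps=100000):
    """Capacity C after `step` steps, linearly increased from min_c to max_c, then held."""
    frac = min(step / float(steps), 1.0)
    return min_c + frac * (max_c - min_c)


def capacity_limited(divergence, step, min_c=0.0, max_c=25.0, steps=100000, gamma=1000.0):
    """gamma * |KL - C|, the capacity-limited objective from Burgess et al. (2018).

    `step` is passed explicitly here; this is an assumption of the sketch.
    """
    return gamma * abs(divergence - capacity_at(step, min_c, max_c, steps))


print(capacity_at(50000))             # 12.5, halfway through the schedule
print(capacity_limited(13.0, 50000))  # 500.0
```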

with_post_function(post_fcn)[source]

Register the given post function, to be applied to the loss after reduction.

Parameters:post_fcn – A function of loss which applies some operation (e.g. multiplying by beta)
Returns:self
Return type:Divergence
with_reduction(reduction_fcn)[source]

Override the reduction operation with the given function. Use this if your divergence doesn’t output a two-dimensional tensor.

Parameters:reduction_fcn – A function which is applied to the divergence output and returns a single value
Returns:self
Return type:Divergence
with_sum_mean_reduction()[source]

Override the reduction function to take a sum over dimension one and a mean over dimension zero. This is the default reduction.

Returns:self
Return type:Divergence
with_sum_sum_reduction()[source]

Override the reduction function to take a sum over all dimensions.

Returns:self
Return type:Divergence
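The two built-in reductions and a post function can be mimicked on a (batch, dims) nested list. The sketch below uses plain Python in place of tensors, and beta_post is a hypothetical post function; it only shows the documented order of operations: compute, then reduction, then post function.

```python
# Plain nested lists stand in for a (batch, dims) divergence tensor.
def sum_mean(divergence):
    """Sum over dimension one (latent dims), then mean over dimension zero (batch)."""
    return sum(sum(row) for row in divergence) / len(divergence)


def sum_sum(divergence):
    """Sum over all dimensions."""
    return sum(sum(row) for row in divergence)


def beta_post(loss, beta=4.0):
    """Hypothetical post function: scale the reduced loss by beta, as in beta-VAE."""
    return beta * loss


kl = [[0.5, 1.0, 0.5],
      [1.0, 1.0, 2.0]]
print(sum_mean(kl))             # 3.0
print(sum_sum(kl))              # 6.0
print(beta_post(sum_mean(kl)))  # 12.0
```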
class torchbearer.variational.divergence.SimpleExponentialSimpleExponentialKL(input_key, target_key, state_key=None)[source]

A KL divergence between two SimpleExponential (or similar) distributions.

Note

The distribution object must have a lograte attribute

Parameters:
  • input_key – StateKey instance which will be mapped to the input distribution object.
  • target_key – StateKey instance which will be mapped to the target distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input, target)[source]

Compute the KL divergence between the given input and target distributions.

Parameters:
  • input – The input distribution, taken from state with input_key
  • target – The target distribution, taken from state with target_key
Returns:The calculated divergence as a two-dimensional tensor (batch, distribution dimensions)
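For reference, the closed-form KL divergence between exponentials with rates r_p and r_q is log r_p - log r_q + r_q/r_p - 1, which can be written entirely in log rates. The scalar sketch below assumes compute returns KL(input || target) in that direction; the argument order is an assumption, not confirmed by the text.

```python
import math


def exponential_kl(lograte_p, lograte_q):
    """KL(Exp(rate_p) || Exp(rate_q)) in terms of log rates (plain-float sketch).

    KL = log(rate_p) - log(rate_q) + rate_q / rate_p - 1
       = lograte_p - lograte_q + exp(lograte_q - lograte_p) - 1
    """
    return lograte_p - lograte_q + math.exp(lograte_q - lograte_p) - 1.0


print(exponential_kl(0.0, 0.0))            # 0.0 for identical distributions
print(exponential_kl(math.log(2.0), 0.0))  # log(2) - 0.5, about 0.193
```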
class torchbearer.variational.divergence.SimpleNormalSimpleNormalKL(input_key, target_key, state_key=None)[source]

A KL divergence between two SimpleNormal (or similar) distributions.

Note

The distribution objects must have mu and logvar attributes

Parameters:
  • input_key – StateKey instance which will be mapped to the input distribution object.
  • target_key – StateKey instance which will be mapped to the target distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input, target)[source]

Compute the KL divergence between the given input and target distributions.

Parameters:
  • input – The input distribution, taken from state with input_key
  • target – The target distribution, taken from state with target_key
Returns:The calculated divergence as a two-dimensional tensor (batch, distribution dimensions)
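The standard closed form for the KL divergence between two Gaussians, per latent dimension, is 0.5·(logvar_q - logvar_p + (exp(logvar_p) + (mu_p - mu_q)²) / exp(logvar_q) - 1). A scalar sketch, assuming compute evaluates KL(input || target) elementwise with this formula (the direction is an assumption):

```python
import math


def normal_kl(mu_p, logvar_p, mu_q, logvar_q):
    """KL(N(mu_p, exp(logvar_p)) || N(mu_q, exp(logvar_q))) for one latent dimension."""
    return 0.5 * (logvar_q - logvar_p
                  + (math.exp(logvar_p) + (mu_p - mu_q) ** 2) / math.exp(logvar_q)
                  - 1.0)


print(normal_kl(1.0, 0.0, 0.0, 0.0))  # 0.5: unit variances, means 1 apart
```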
class torchbearer.variational.divergence.SimpleNormalUnitNormalKL(input_key, state_key=None)[source]

A KL divergence between a SimpleNormal (or similar) instance and a fixed unit normal (N(0, 1)) target.

Note

The distribution object must have mu and logvar attributes

Parameters:
  • input_key – StateKey instance which will be mapped to the distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input)[source]

Compute the KL divergence between the given input distribution and a unit normal.

Parameters:input – The input distribution, taken from state with input_key
Returns:The calculated divergence as a two-dimensional tensor (batch, distribution dimensions)
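Against a unit normal target, the Gaussian KL reduces to the closed form -0.5·(1 + logvar - mu² - exp(logvar)) per latent dimension. The sketch below applies it elementwise to nested lists to mirror the documented (batch, distribution dimensions) output; it is an illustration with plain floats, not torchbearer's code.

```python
import math


def unit_normal_kl(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) for one latent dimension."""
    return -0.5 * (1.0 + logvar - mu ** 2 - math.exp(logvar))


# One row per batch element gives the (batch, dims) layout that compute() returns.
mus = [[0.0, 1.0], [2.0, 0.0]]
logvars = [[0.0, 0.0], [0.0, math.log(2.0)]]
kl = [[unit_normal_kl(m, lv) for m, lv in zip(mu_row, lv_row)]
      for mu_row, lv_row in zip(mus, logvars)]
print(kl)
```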
class torchbearer.variational.divergence.SimpleWeibullSimpleWeibullKL(input_key, target_key, state_key=None)[source]

A KL divergence between two SimpleWeibull (or similar) distributions.

Note

The distribution object must have lambda (scale) and k (shape) attributes
@article{DBLP:journals/corr/Bauckhage14,
  author    = {Christian Bauckhage},
  title     = {Computing the Kullback-Leibler Divergence between two Generalized
               Gamma Distributions},
  journal   = {CoRR},
  volume    = {abs/1401.6853},
  year      = {2014}
}
Parameters:
  • input_key – StateKey instance which will be mapped to the input distribution object.
  • target_key – StateKey instance which will be mapped to the target distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input, target)[source]

Compute the KL divergence between the given input and target distributions.

Parameters:
  • input – The input distribution, taken from state with input_key
  • target – The target distribution, taken from state with target_key
Returns:The calculated divergence as a two-dimensional tensor (batch, distribution dimensions)
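For reference, the closed-form KL divergence between two Weibull distributions (a special case of the generalised gamma result cited above) is log(k1/l1^k1) - log(k2/l2^k2) + (k1 - k2)·(log l1 - γ/k1) + (l1/l2)^k2·Γ(k2/k1 + 1) - 1, where γ is the Euler-Mascheroni constant. The scalar sketch below evaluates it, assuming KL(input || target) with shapes k and scales l; which argument order torchbearer uses is an assumption.

```python
import math

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def weibull_kl(k1, l1, k2, l2):
    """KL(Weibull(k1, l1) || Weibull(k2, l2)); k are shapes, l are scales.

    KL = log(k1 / l1**k1) - log(k2 / l2**k2)
         + (k1 - k2) * (log(l1) - EULER_GAMMA / k1)
         + (l1 / l2) ** k2 * Gamma(k2 / k1 + 1) - 1
    """
    return (math.log(k1 / l1 ** k1) - math.log(k2 / l2 ** k2)
            + (k1 - k2) * (math.log(l1) - EULER_GAMMA / k1)
            + (l1 / l2) ** k2 * math.gamma(k2 / k1 + 1.0)
            - 1.0)


print(weibull_kl(1.5, 2.0, 1.5, 2.0))  # approximately 0 for identical distributions
```

With k1 = k2 = 1 both distributions are exponential, and the expression reduces to the exponential KL, which is a convenient sanity check.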

Auto-Encoding

class torchbearer.variational.auto_encoder.AutoEncoderBase(latent_dims)[source]
decode(sample, state=None)[source]

Decode the given latent space sample batch to images.

Parameters:
  • sample – The latent space samples
  • state – The trial state
Returns:

Decoded images

encode(x, state=None)[source]

Encode the given batch of images and return a latent space sample for each.

Parameters:
  • x – Batch of images to encode
  • state – The trial state
Returns:

Encoded samples / tuple of samples for different spaces

forward(x, state=None)[source]

Encode then decode the inputs, returning the result. Also binds the input images as the target in state.

Parameters:
  • x – Model input batch
  • state – The trial state
Returns:

Auto-Encoded images
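The encode / decode / forward contract can be illustrated without torch. The toy class below is hypothetical (ToyAutoEncoder is not part of torchbearer and does not subclass AutoEncoderBase), and 'y_true' is an assumed name for the state key under which the target is bound:

```python
class ToyAutoEncoder:
    """Hypothetical plain-Python stand-in for the encode / decode / forward contract.

    Not torchbearer's AutoEncoderBase; 'y_true' is an assumed state key.
    """

    def __init__(self, latent_dims, output_dims):
        self.latent_dims = latent_dims
        self.output_dims = output_dims

    def encode(self, x, state=None):
        # Toy encoder: keep the first latent_dims values of each input row.
        return [row[:self.latent_dims] for row in x]

    def decode(self, sample, state=None):
        # Toy decoder: pad each latent code back out to output_dims with zeros.
        return [code + [0.0] * (self.output_dims - len(code)) for code in sample]

    def forward(self, x, state=None):
        if state is not None:
            state['y_true'] = x  # bind the inputs as the reconstruction target
        return self.decode(self.encode(x, state), state)


state = {}
model = ToyAutoEncoder(latent_dims=2, output_dims=4)
print(model.forward([[1.0, 2.0, 3.0, 4.0]], state))  # [[1.0, 2.0, 0.0, 0.0]]
print(state['y_true'])                              # [[1.0, 2.0, 3.0, 4.0]]
```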

Datasets

class torchbearer.variational.datasets.CelebA(root, transform=None, target_transform=None)[source]
class torchbearer.variational.datasets.CelebA_HQ(root, as_npy=False, transform=None)[source]
static npy_loader(path)[source]
class torchbearer.variational.datasets.SimpleImageFolder(root, loader=None, extensions=None, transform=None, target_transform=None)[source]
class torchbearer.variational.datasets.dSprites(root, download=False, transform=None)[source]
download()[source]
get_img_by_latent(latent_code)[source]

Returns the image defined by the latent code

Parameters:latent_code (list of int) – Latent code of length 6 defining each generative factor
Returns:Image defined by given code
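dSprites has six generative factors taking 1, 3, 6, 40, 32 and 32 values (color, shape, scale, orientation, x position, y position). Assuming the images are stored in row-major order of those factors, as in the original dSprites release, a latent code maps to an image index as in the sketch below; latent_to_index is a hypothetical helper, not torchbearer's method.

```python
# Values per generative factor in dSprites (assumed storage order):
# color, shape, scale, orientation, posX, posY.
LATENT_SIZES = [1, 3, 6, 40, 32, 32]


def latent_to_index(latent_code, sizes=LATENT_SIZES):
    """Row-major index of the image with the given 6-factor latent code (hypothetical helper)."""
    if len(latent_code) != len(sizes):
        raise ValueError('expected a latent code of length %d' % len(sizes))
    index = 0
    for value, size in zip(latent_code, sizes):
        if not 0 <= value < size:
            raise ValueError('latent value out of range')
        index = index * size + value
    return index


print(latent_to_index([0, 0, 0, 0, 0, 1]))     # 1
print(latent_to_index([0, 2, 5, 39, 31, 31]))  # 737279, the last image
```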
load_data()[source]
torchbearer.variational.datasets.make_dataset(dir, extensions)[source]

Visualisation

class torchbearer.variational.visualisation.CodePathWalker(num_steps, p1, p2)[source]
vis(state)[source]

Create the tensor of images to be displayed

class torchbearer.variational.visualisation.ImagePathWalker(num_steps, im1, im2)[source]
vis(state)[source]

Create the tensor of images to be displayed

class torchbearer.variational.visualisation.LatentWalker(same_image, row_size)[source]
for_data(data_key)[source]
Parameters:data_key (StateKey) – State key which will contain data to act on
Returns:self
Return type:LatentWalker
for_space(space_id)[source]

Sets the ID of the latent space to vary when the model outputs [latent_space_0, latent_space_1, …]

Parameters:space_id (int) – ID of the latent space to vary
Returns:self
Return type:LatentWalker
on_train()[source]

Sets the walker to run during training

Returns:self
Return type:LatentWalker
on_val()[source]

Sets the walker to run during validation

Returns:self
Return type:LatentWalker
to_file(file)[source]
Parameters:file (string, pathlib.Path object or file object) – File in which result is saved
Returns:self
Return type:LatentWalker
to_key(state_key)[source]
Parameters:state_key (StateKey) – State key under which to store result
Returns:self
Return type:LatentWalker
vis(state)[source]

Create the tensor of images to be displayed

class torchbearer.variational.visualisation.LinSpaceWalker(lin_start=-1, lin_end=1, lin_steps=8, dims_to_walk=[0], zero_init=False, same_image=False)[source]
vis(state)[source]

Create the tensor of images to be displayed

class torchbearer.variational.visualisation.RandomWalker(var=1, num_images=32, uniform=False, row_size=8)[source]
vis(state)[source]

Create the tensor of images to be displayed

class torchbearer.variational.visualisation.ReconstructionViewer(row_size=8, recon_key=y_pred)[source]
vis(state)[source]

Create the tensor of images to be displayed