torchbearer.variational
Distributions
The distributions module is an extension of the torch.distributions package, intended to facilitate implementations
required for specific variational approaches through the SimpleDistribution class. Generally, using a
torch.distributions.Distribution object should be preferred over a SimpleDistribution, for better argument validation
and more complete implementations. However, if you need to implement something new for a specific variational
approach, then a SimpleDistribution may be more forgiving. Furthermore, you may find it easier to understand the
function of the implementations here.

class torchbearer.variational.distributions.SimpleDistribution(batch_shape=torch.Size(), event_shape=torch.Size())
Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob function is not differentiable with respect to the distribution parameters or the given value, then this should be mentioned in the documentation.

arg_constraints

has_rsample = True

log_prob(value)
Returns the log of the probability density/mass function evaluated at value.
Parameters: value (torch.Tensor, Number) – Value at which to evaluate the log probability

mean

rsample(sample_shape=torch.Size())
Returns a reparameterized sample, or a batch of reparameterized samples if the distribution parameters are batched.

support

variance


class torchbearer.variational.distributions.SimpleExponential(lograte)
The SimpleExponential class is a SimpleDistribution which implements a straightforward Exponential distribution with the given lograte. This performs significantly fewer checks than torch.distributions.Exponential, but should be sufficient for the purpose of implementing a VAE. By using a lograte, the log_prob can be computed in a stable fashion, without taking a logarithm.
Parameters: lograte (torch.Tensor, Number) – The natural log of the rate of the distribution, numbers will be cast to tensors

log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. The log_prob for this distribution is fully differentiable and has a stable gradient, since the lograte is used directly.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution

rsample(sample_shape=torch.Size())
Simple rsample for an Exponential distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per lograte given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
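
As a rough illustration of why a lograte avoids the logarithm, the following plain-Python sketch evaluates the Exponential log density and draws a sample by inverting the CDF. The function names are hypothetical and the code works on floats rather than tensors; it shows the arithmetic only and is not the library's implementation.

```python
import math
import random


def exponential_log_prob(value, lograte):
    # log p(x) = log(rate) - rate * x, and log(rate) is the lograte itself,
    # so no logarithm of a learned quantity is needed.
    return lograte - math.exp(lograte) * value


def exponential_rsample(lograte, rng=random):
    # Inverse-CDF sampling: x = -log(1 - u) / rate with u ~ U[0, 1).
    # The noise u does not depend on the parameters, which is what makes
    # the sample "reparameterized".
    u = rng.random()
    return -math.log1p(-u) / math.exp(lograte)
```

With a rate of 2 (lograte = log 2), the average of many such samples should be close to 1 / 2.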


class torchbearer.variational.distributions.SimpleNormal(mu, logvar)
The SimpleNormal class is a SimpleDistribution which implements a straightforward Normal / Gaussian distribution. This performs significantly fewer checks than torch.distributions.Normal, but should be sufficient for the purpose of implementing a VAE.
Parameters:
 mu (torch.Tensor, Number) – The mean of the distribution, numbers will be cast to tensors
 logvar (torch.Tensor, Number) – The log variance of the distribution, numbers will be cast to tensors

log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. Since the density of a Gaussian is differentiable, this function is differentiable.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution

rsample(sample_shape=torch.Size())
Simple rsample for a Normal distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per mean / variance given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
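
The usual reparameterization for a Gaussian parameterized by a mean and a log variance can be sketched in plain Python as below. The function names are hypothetical, and the code uses floats instead of tensors; it is meant to show the formulas and may differ in detail from what SimpleNormal does internally.

```python
import math
import random


def normal_rsample(mu, logvar, rng=random):
    # Reparameterization trick: sample eps ~ N(0, 1) and shift/scale it,
    # so the result is a deterministic function of mu and logvar.
    eps = rng.gauss(0.0, 1.0)
    return mu + math.exp(0.5 * logvar) * eps


def normal_log_prob(value, mu, logvar):
    # log N(x; mu, var) = -0.5 * (log(2 * pi) + log(var) + (x - mu)^2 / var)
    return -0.5 * (
        math.log(2.0 * math.pi) + logvar + (value - mu) ** 2 / math.exp(logvar)
    )
```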

class torchbearer.variational.distributions.SimpleUniform(low, high)
The SimpleUniform class is a SimpleDistribution which implements a straightforward Uniform distribution in the interval [low, high). This performs significantly fewer checks than torch.distributions.Uniform, but should be sufficient for the purpose of implementing a VAE.
Parameters:
 low (torch.Tensor, Number) – The lower range of the distribution (inclusive), numbers will be cast to tensors
 high (torch.Tensor, Number) – The upper range of the distribution (exclusive), numbers will be cast to tensors

log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. Since this distribution is uniform, the log probability is -log(high - low) for all values in the range [low, high) and -inf elsewhere. This function is therefore only piecewise differentiable.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution

rsample(sample_shape=torch.Size())
Simple rsample for a Uniform distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per low / high given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
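
The piecewise behaviour of the Uniform log probability can be seen in a small plain-Python sketch. The function names are hypothetical and the code operates on floats, so it illustrates the definition rather than reproducing the library code.

```python
import math
import random


def uniform_log_prob(value, low, high):
    # Constant density 1 / (high - low) inside [low, high), zero outside.
    if low <= value < high:
        return -math.log(high - low)
    return -math.inf


def uniform_rsample(low, high, rng=random):
    # Scale and shift parameter-free noise u ~ U[0, 1).
    return low + rng.random() * (high - low)
```

Because the log probability jumps to -inf at the interval boundaries, gradients are only meaningful for values strictly inside the interval.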

class torchbearer.variational.distributions.SimpleWeibull(l, k)
The SimpleWeibull class is a SimpleDistribution which implements a straightforward Weibull distribution. This performs significantly fewer checks than torch.distributions.Weibull, but should be sufficient for the purpose of implementing a VAE.
@article{squires2019a, title={A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation}, author={Steven Squires and Adam Prugel-Bennett and Mahesan Niranjan}, year={2019} }
Parameters:
 l (torch.Tensor, Number) – The scale parameter of the distribution, numbers will be cast to tensors
 k (torch.Tensor, Number) – The shape parameter of the distribution, numbers will be cast to tensors

log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. This function is differentiable and its log probability is -inf for values less than 0.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution

rsample(sample_shape=torch.Size())
Simple rsample for a Weibull distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per k / lambda given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
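
A Weibull distribution with scale l and shape k has a closed-form inverse CDF, which is what makes a reparameterized sample straightforward. The sketch below uses hypothetical function names and plain floats, and for simplicity it also returns -inf at the boundary value 0; treat it as an illustration of the formulas, not as the library's code.

```python
import math
import random


def weibull_log_prob(value, l, k):
    # log p(x) = log(k / l) + (k - 1) * log(x / l) - (x / l)^k for x > 0.
    if value <= 0:
        return -math.inf  # simplification: the boundary x = 0 is excluded too
    z = value / l
    return math.log(k / l) + (k - 1.0) * math.log(z) - z ** k


def weibull_rsample(l, k, rng=random):
    # Inverse-CDF sampling: x = l * (-log(1 - u))^(1 / k) with u ~ U[0, 1).
    u = rng.random()
    return l * (-math.log1p(-u)) ** (1.0 / k)
```

With k = 1 the expressions reduce to those of an Exponential distribution with rate 1 / l.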
Divergences

class torchbearer.variational.divergence.DivergenceBase(keys, state_key=None)
The DivergenceBase class is an abstract base class which defines a series of useful methods for dealing with divergences. The keys dict given on init is used to map objects in state to kwargs in the compute function.

compute(**kwargs)
Compute the loss with the given kwargs defined in the constructor.
Parameters: kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns: The calculated divergence as a two dimensional tensor (batch, distribution dimensions)

on_criterion(state)
Perform some action with the given state as context after the criterion has been evaluated.
Parameters: state (dict) – The current state dict of the Trial.

on_criterion_validation(state)
Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.
Parameters: state (dict) – The current state dict of the Trial.

with_beta(beta)
Multiply the divergence by the given beta, as introduced by beta-VAE.
@article{higgins2016beta, title={beta-VAE: Learning basic visual concepts with a constrained variational framework}, author={Higgins, Irina and Matthey, Loic and Pal, Arka and Burgess, Christopher and Glorot, Xavier and Botvinick, Matthew and Mohamed, Shakir and Lerchner, Alexander}, year={2016} }
Parameters: beta (float) – The beta (> 1) to multiply by.
Returns: self
Return type: Divergence

with_linear_capacity(min_c=0, max_c=25, steps=100000, gamma=1000)
Limit divergence by capacity, linearly increased from min_c to max_c over the given number of steps, as introduced in Understanding disentangling in beta-VAE.
@article{burgess2018understanding, title={Understanding disentangling in beta-VAE}, author={Burgess, Christopher P and Higgins, Irina and Pal, Arka and Matthey, Loic and Watters, Nick and Desjardins, Guillaume and Lerchner, Alexander}, journal={arXiv preprint arXiv:1804.03599}, year={2018} }
Parameters:
 min_c (float) – Minimum capacity
 max_c (float) – Maximum capacity
 steps (int) – Number of steps to increase over
 gamma (float) – Multiplicative gamma, usually a high number
Returns: self
Return type: Divergence
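
The capacity objective in the cited paper penalises the distance between the divergence and a target capacity C, scaled by gamma, with C growing linearly during training. Assuming with_linear_capacity follows that form, the schedule can be sketched in plain Python as below; the function names and the clamping at the ends of the schedule are assumptions made for illustration.

```python
def linear_capacity(step, min_c=0.0, max_c=25.0, steps=100000):
    # Fraction of the schedule completed, clamped to [0, 1] (assumed behaviour).
    frac = min(max(step / steps, 0.0), 1.0)
    return min_c + (max_c - min_c) * frac


def capacity_limited(divergence, step, min_c=0.0, max_c=25.0,
                     steps=100000, gamma=1000.0):
    # gamma * |KL - C|: the divergence is pulled towards the current capacity.
    c = linear_capacity(step, min_c, max_c, steps)
    return gamma * abs(divergence - c)
```

Half way through the schedule the target capacity is 12.5 with the parameter values shown, so a divergence of 10.0 is penalised by 1000 * 2.5.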

with_post_function(post_fcn)
Register the given post function, to be applied to the loss after reduction.
Parameters: post_fcn – A function of loss which applies some operation (e.g. multiplying by beta)
Returns: self
Return type: Divergence

with_reduction(reduction_fcn)
Override the reduction operation with the given function. Use this if your divergence doesn't output a two dimensional tensor.
Parameters: reduction_fcn – The function to be applied to the divergence output, returning a single value
Returns: self
Return type: Divergence
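
The order of operations described here (compute, then reduce, then apply post functions) can be sketched with nested lists standing in for a (batch, distribution dimensions) tensor. Everything in this block is hypothetical: the helper names are invented, and the sum-then-mean reduction is only an assumed example, not necessarily the library's default.

```python
def sum_then_mean(divergence):
    # Sum over distribution dimensions, then average over the batch.
    per_sample = [sum(row) for row in divergence]
    return sum(per_sample) / len(per_sample)


def apply_pipeline(divergence, reduction_fcn, post_fcns=()):
    # The reduction runs first and yields a single value; post functions
    # (for example a multiplication by beta) are applied to that value in order.
    loss = reduction_fcn(divergence)
    for fcn in post_fcns:
        loss = fcn(loss)
    return loss


beta = 4.0
example = [[1.0, 2.0], [3.0, 4.0]]  # two samples, two latent dimensions
scaled = apply_pipeline(example, sum_then_mean, [lambda loss: beta * loss])
```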


class torchbearer.variational.divergence.SimpleExponentialSimpleExponentialKL(input_key, target_key, state_key=None)
A KL divergence between two SimpleExponential (or similar) distributions.
Note
The distribution objects must have a lograte attribute
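
For two Exponential distributions p and q with rates r_p and r_q, the KL divergence has the closed form log(r_p) - log(r_q) + r_q / r_p - 1, which is convenient to write in terms of log rates. The sketch below uses a hypothetical function name and plain floats; which of the two arguments corresponds to input_key and which to target_key is not stated here and is left open.

```python
import math


def exponential_kl(lograte_p, lograte_q):
    # KL(p || q) = log(r_p) - log(r_q) + r_q / r_p - 1,
    # with r_q / r_p written as exp(lograte_q - lograte_p).
    return lograte_p - lograte_q + math.exp(lograte_q - lograte_p) - 1.0
```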

class torchbearer.variational.divergence.SimpleNormalSimpleNormalKL(input_key, target_key, state_key=None)
A KL divergence between two SimpleNormal (or similar) distributions.
Note
The distribution objects must have mu and logvar attributes

class torchbearer.variational.divergence.SimpleNormalUnitNormalKL(input_key, state_key=None)
A KL divergence between a SimpleNormal (or similar) instance and a fixed unit normal, N(0, 1), target.
Note
The distribution object must have mu and logvar attributes
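
The two Gaussian divergences above have well-known closed forms when each distribution is described by a mean and a log variance. The following plain-Python sketch, with hypothetical function names and floats instead of tensors, shows both and makes it visible that the unit normal case is the general formula with a target of mu = 0 and logvar = 0.

```python
import math


def normal_kl(mu_p, logvar_p, mu_q, logvar_q):
    # KL(N(mu_p, var_p) || N(mu_q, var_q)) per dimension:
    # 0.5 * (log(var_q / var_p) + (var_p + (mu_p - mu_q)^2) / var_q - 1)
    return 0.5 * (
        logvar_q - logvar_p
        + (math.exp(logvar_p) + (mu_p - mu_q) ** 2) / math.exp(logvar_q)
        - 1.0
    )


def unit_normal_kl(mu, logvar):
    # Special case with target N(0, 1): -0.5 * (1 + logvar - mu^2 - var)
    return -0.5 * (1.0 + logvar - mu ** 2 - math.exp(logvar))
```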

class torchbearer.variational.divergence.SimpleWeibullSimpleWeibullKL(input_key, target_key, state_key=None)
A KL divergence between two SimpleWeibull (or similar) distributions.
Note
The distribution objects must have lambda (scale) and k (shape) attributes
@article{DBLP:journals/corr/Bauckhage14, author = {Christian Bauckhage}, title = {Computing the Kullback-Leibler Divergence between two Generalized Gamma Distributions}, journal = {CoRR}, volume = {abs/1401.6853}, year = {2014} }
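
The closed form usually quoted for the KL divergence between two Weibull distributions with scales l1, l2 and shapes k1, k2 involves the gamma function and the Euler-Mascheroni constant. The sketch below assumes that form; the function name is hypothetical and the code works on floats, so it should be read as an illustration rather than the class's actual computation.

```python
import math

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def weibull_kl(l1, k1, l2, k2):
    # KL(W(l1, k1) || W(l2, k2)) =
    #   log(k1 / l1^k1) - log(k2 / l2^k2)
    #   + (k1 - k2) * (log(l1) - gamma_E / k1)
    #   + (l1 / l2)^k2 * Gamma(k2 / k1 + 1) - 1
    return (
        math.log(k1 / l1 ** k1)
        - math.log(k2 / l2 ** k2)
        + (k1 - k2) * (math.log(l1) - EULER_GAMMA / k1)
        + (l1 / l2) ** k2 * math.gamma(k2 / k1 + 1.0)
        - 1.0
    )
```

When both shapes equal 1 the expression reduces to the Exponential case with rates 1 / l1 and 1 / l2, which gives a quick sanity check.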
AutoEncoding

class torchbearer.variational.auto_encoder.AutoEncoderBase(latent_dims)

decode(sample, state=None)
Decode the given latent space sample batch to images.
Parameters:
 sample – The latent space samples
 state – The trial state
Returns: Decoded images

Datasets

class torchbearer.variational.datasets.SimpleImageFolder(root, loader=None, extensions=None, transform=None, target_transform=None)

class torchbearer.variational.datasets.dSprites(root, download=False, transform=None)
Visualisation

class torchbearer.variational.visualisation.LatentWalker(same_image, row_size)

for_data(data_key)
Parameters: data_key (StateKey) – State key which will contain data to act on
Returns: self
Return type: LatentWalker

for_space(space_id)
Sets the ID for which latent space to vary when the model outputs [latent_space_0, latent_space_1, …]
Parameters: space_id (int) – ID of the latent space to vary
Returns: self
Return type: LatentWalker

on_train()
Sets the walker to run during training
Returns: self
Return type: LatentWalker

on_val()
Sets the walker to run during validation
Returns: self
Return type: LatentWalker

to_file(file)
Parameters: file (string, pathlib.Path object or file object) – File in which result is saved
Returns: self
Return type: LatentWalker

to_key(state_key)
Parameters: state_key (StateKey) – State key under which to store result
Returns: self
Return type: LatentWalker


class torchbearer.variational.visualisation.LinSpaceWalker(lin_start=-1, lin_end=1, lin_steps=8, dims_to_walk=[0], zero_init=False, same_image=False)
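
The idea behind a linear latent walk can be sketched without any library code: pick evenly spaced values between a start and an end, then overwrite one latent dimension with each value while holding the others fixed. The helper names below are hypothetical, the start and end values are chosen only for the example, and the sketch says nothing about how LinSpaceWalker decodes or arranges the resulting images.

```python
def linspace(start, end, steps):
    # Evenly spaced values from start to end, both inclusive.
    if steps == 1:
        return [float(start)]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def walk_dimension(base, dim, values):
    # Copy the base latent vector once per value, varying only `dim`.
    walked = []
    for v in values:
        z = list(base)
        z[dim] = v
        walked.append(z)
    return walked


values = linspace(-1.0, 1.0, 5)
latents = walk_dimension([0.0, 0.0, 0.0], dim=1, values=values)
```

Each entry of latents could then be passed to a decoder to see how the chosen dimension affects the output.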