torchbearer.variational
Distributions
The distributions module is an extension of the torch.distributions package, intended to facilitate the implementations required by specific variational approaches through the SimpleDistribution class. In general, a torch.distributions.Distribution object should be preferred over a SimpleDistribution, since it offers better argument validation and more complete implementations. However, if you need to implement something new for a specific variational approach, a SimpleDistribution may be more forgiving. You may also find the implementations here easier to understand.
class torchbearer.variational.distributions.SimpleDistribution(batch_shape=torch.Size(), event_shape=torch.Size())
Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob function is not differentiable with respect to the distribution parameters or the given value, this should be stated in the documentation.
arg_constraints
has_rsample = True
log_prob(value)
Returns the log of the probability density/mass function evaluated at value.
Parameters: value (torch.Tensor, Number) – Value at which to evaluate the log probability
mean
rsample(sample_shape=torch.Size())
Returns a reparameterized sample, or a batch of reparameterized samples if the distribution parameters are batched.
support
variance
class torchbearer.variational.distributions.SimpleExponential(lograte)
The SimpleExponential class is a SimpleDistribution which implements a straightforward Exponential distribution with the given lograte. It performs significantly fewer checks than torch.distributions.Exponential, but should be sufficient for implementing a VAE. Because the distribution is parameterized by the lograte, log_prob can be computed stably without taking a logarithm.
Parameters: lograte (torch.Tensor, Number) – The natural log of the rate of the distribution; numbers will be cast to tensors
log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. The log_prob for this distribution is fully differentiable and has a stable gradient, since it is computed from the lograte.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())
Simple rsample for an Exponential distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per lograte given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
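To make the lograte parameterization concrete, here is a scalar, pure-Python sketch of the two formulas. It is not torchbearer's tensor implementation, and the function names are hypothetical. The log-density of an Exponential with rate exp(lograte) is lograte - exp(lograte) * x, so no logarithm is taken; a reparameterized sample is obtained by pushing uniform noise through the inverse CDF.

```python
import math
import random


def exponential_log_prob(lograte: float, value: float) -> float:
    """Log-density of Exp(rate=exp(lograte)); -inf outside the support x >= 0."""
    if value < 0:
        return -math.inf
    # log(rate * exp(-rate * x)) = lograte - rate * x, so no log() call is needed
    return lograte - math.exp(lograte) * value


def exponential_rsample(lograte: float, rng: random.Random) -> float:
    """Reparameterized sample: a deterministic function of lograte and noise u."""
    u = rng.random()  # u ~ Uniform[0, 1), independent of the parameters
    # Inverse CDF of the Exponential: x = -log(1 - u) / rate
    return -math.log1p(-u) * math.exp(-lograte)


rng = random.Random(0)
samples = [exponential_rsample(math.log(2.0), rng) for _ in range(20000)]
print(sum(samples) / len(samples))  # close to 1 / rate = 0.5
```

Since the noise u does not depend on lograte, the sample is a differentiable function of the parameter, which is what rsample promises.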
class torchbearer.variational.distributions.SimpleNormal(mu, logvar)
The SimpleNormal class is a SimpleDistribution which implements a straightforward Normal / Gaussian distribution. It performs significantly fewer checks than torch.distributions.Normal, but should be sufficient for implementing a VAE.
Parameters:
- mu (torch.Tensor, Number) – The mean of the distribution; numbers will be cast to tensors
- logvar (torch.Tensor, Number) – The log variance of the distribution; numbers will be cast to tensors
log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. Since the density of a Gaussian is differentiable, this function is differentiable.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())
Simple rsample for a Normal distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per mean / variance given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
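The following scalar sketch, written in plain Python rather than with torch tensors and using hypothetical function names, shows the Gaussian log-density in terms of mu and logvar and the reparameterization trick x = mu + exp(0.5 * logvar) * eps. A finite difference illustrates why such a sample carries a gradient with respect to the parameters.

```python
import math
import random


def normal_log_prob(mu: float, logvar: float, value: float) -> float:
    """Log-density of N(mu, exp(logvar)) at value."""
    var = math.exp(logvar)
    return -0.5 * (math.log(2 * math.pi) + logvar + (value - mu) ** 2 / var)


def normal_rsample(mu: float, logvar: float, eps: float) -> float:
    """Reparameterization trick: x = mu + sigma * eps with eps ~ N(0, 1)."""
    return mu + math.exp(0.5 * logvar) * eps


rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)  # noise is drawn once, independently of mu and logvar
x = normal_rsample(1.0, 0.0, eps)

# Because x is a deterministic function of (mu, logvar, eps), it has a gradient
# with respect to mu; a finite difference recovers dx/dmu = 1.
h = 1e-6
dx_dmu = (normal_rsample(1.0 + h, 0.0, eps) - x) / h
print(round(dx_dmu, 4))  # 1.0
```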
class torchbearer.variational.distributions.SimpleUniform(low, high)
The SimpleUniform class is a SimpleDistribution which implements a straightforward Uniform distribution on the interval [low, high). It performs significantly fewer checks than torch.distributions.Uniform, but should be sufficient for implementing a VAE.
Parameters:
- low (torch.Tensor, Number) – The lower bound of the distribution (inclusive); numbers will be cast to tensors
- high (torch.Tensor, Number) – The upper bound of the distribution (exclusive); numbers will be cast to tensors
log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. Since this distribution is uniform, the log probability is -log(high - low) for all values in the range [low, high) and -inf elsewhere. This function is therefore only piecewise differentiable.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())
Simple rsample for a Uniform distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per low / high given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
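A scalar sketch of the piecewise log-density and of a shift-and-scale reparameterized sample, in plain Python with hypothetical function names (torchbearer itself operates on tensors):

```python
import math
import random


def uniform_log_prob(low: float, high: float, value: float) -> float:
    """-log(high - low) inside [low, high), -inf elsewhere."""
    if low <= value < high:
        return -math.log(high - low)
    return -math.inf


def uniform_rsample(low: float, high: float, u: float) -> float:
    """Reparameterized sample: shift and scale u ~ Uniform[0, 1)."""
    return low + (high - low) * u


rng = random.Random(0)
x = uniform_rsample(2.0, 6.0, rng.random())
print(uniform_log_prob(2.0, 6.0, x) == -math.log(4.0))  # True
print(uniform_log_prob(2.0, 6.0, 6.0))  # -inf (high is excluded)
```

The sample is linear in low and high, so its gradient with respect to them is well defined even though the log-density jumps to -inf at the interval boundaries.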
class torchbearer.variational.distributions.SimpleWeibull(l, k)
The SimpleWeibull class is a SimpleDistribution which implements a straightforward Weibull distribution. It performs significantly fewer checks than torch.distributions.Weibull, but should be sufficient for implementing a VAE.
Reference: Steven Squires, Adam Prugel-Bennett and Mahesan Niranjan, "A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation", 2019.
Parameters:
- l (torch.Tensor, Number) – The scale parameter of the distribution; numbers will be cast to tensors
- k (torch.Tensor, Number) – The shape parameter of the distribution; numbers will be cast to tensors
log_prob(value)
Calculates the log probability that the given value was drawn from this distribution. This function is differentiable, and the log probability is -inf for values less than 0.
Parameters: value (torch.Tensor, Number) – The sampled value
Returns: The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())
Simple rsample for a Weibull distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per scale l / shape k given)
Returns: A reparameterized sample with gradient with respect to the distribution parameters
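Assuming the standard parameterization with scale l and shape k, a Weibull sample can be reparameterized through its inverse CDF, x = l * (-log(1 - u)) ** (1 / k). The plain-Python sketch below (hypothetical names, scalars only) pairs that with the log-density; for simplicity it also treats value == 0 as outside the support, which the documented log_prob does not necessarily do.

```python
import math
import random


def weibull_log_prob(l: float, k: float, value: float) -> float:
    """Log-density of Weibull(scale=l, shape=k); -inf for negative values.

    Simplification in this sketch: value == 0 is also treated as outside the support.
    """
    if value <= 0:
        return -math.inf
    z = value / l
    return math.log(k / l) + (k - 1) * math.log(z) - z ** k


def weibull_rsample(l: float, k: float, u: float) -> float:
    """Inverse-CDF reparameterization: x = l * (-log(1 - u)) ** (1 / k)."""
    return l * (-math.log1p(-u)) ** (1.0 / k)


rng = random.Random(0)
samples = [weibull_rsample(2.0, 1.0, rng.random()) for _ in range(20000)]
# With k = 1 the Weibull reduces to an Exponential with mean l = 2.
print(sum(samples) / len(samples))  # close to 2.0
```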
Divergences
class torchbearer.variational.divergence.DivergenceBase(keys, state_key=None)
The DivergenceBase class is an abstract base class which defines a series of useful methods for dealing with divergences. The keys dict given on init is used to map objects in state to kwargs in the compute function.
compute(**kwargs)
Compute the divergence from the given kwargs, which are bound using the keys given in the constructor.
Parameters: kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns: The calculated divergence as a two-dimensional tensor (batch, distribution dimensions)
on_criterion(state)
Perform some action with the given state as context after the criterion has been evaluated.
Parameters: state (dict) – The current state dict of the Trial.
on_criterion_validation(state)
Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.
Parameters: state (dict) – The current state dict of the Trial.
with_beta(beta)
Multiply the divergence by the given beta, as introduced by beta-VAE.
Reference: Irina Higgins, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed and Alexander Lerchner, "beta-VAE: Learning basic visual concepts with a constrained variational framework", 2016.
Parameters: beta (float) – The beta (> 1) to multiply by
Returns: self
Return type: Divergence
with_linear_capacity(min_c=0, max_c=25, steps=100000, gamma=1000)
Limit the divergence by a capacity which is increased linearly from min_c to max_c over the given number of steps, as introduced in Understanding disentangling in beta-VAE.
Reference: Christopher P. Burgess, Irina Higgins, Arka Pal, Loic Matthey, Nick Watters, Guillaume Desjardins and Alexander Lerchner, "Understanding disentangling in beta-VAE", arXiv preprint arXiv:1804.03599, 2018.
Parameters:
- min_c (float) – Minimum capacity
- max_c (float) – Maximum capacity
- steps (int) – Number of steps over which to increase the capacity
- gamma (float) – Multiplicative gamma, usually a large number
Returns: self
Return type: Divergence
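The capacity schedule can be sketched in plain Python as follows. The sketch assumes the objective gamma * |KL - C| from Burgess et al. and assumes the capacity C is advanced once per training step; the function names are hypothetical and this is not torchbearer's code.

```python
def linear_capacity(step: int, min_c: float = 0.0, max_c: float = 25.0,
                    steps: int = 100000) -> float:
    """Capacity C, increased linearly from min_c to max_c over `steps`, then held."""
    frac = min(step / steps, 1.0)
    return min_c + (max_c - min_c) * frac


def capacity_loss(kl: float, step: int, gamma: float = 1000.0, **kwargs) -> float:
    """gamma * |KL - C|: penalizes the KL for deviating from the current capacity."""
    return gamma * abs(kl - linear_capacity(step, **kwargs))


print(linear_capacity(0), linear_capacity(50000), linear_capacity(200000))
# 0.0 12.5 25.0
print(capacity_loss(kl=10.0, step=50000))  # 1000 * |10 - 12.5| = 2500.0
```

A large gamma keeps the KL pinned close to C, so the latent code is allowed to carry more information only as the capacity grows.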
with_post_function(post_fcn)
Register the given post function, to be applied to the loss after reduction.
Parameters: post_fcn – A function of the loss which applies some operation (e.g. multiplying by beta)
Returns: self
Return type: Divergence
with_reduction(reduction_fcn)
Override the reduction operation with the given function. Use this if your divergence does not output a two-dimensional tensor.
Parameters: reduction_fcn – The function to be applied to the divergence output, returning a single value
Returns: self
Return type: Divergence
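The builder methods above can be pictured as a small pipeline: compute the divergence, reduce it to a single value, then apply each registered post function in turn, with with_beta modelled here as one such post function. The class below is a hypothetical, list-based stand-in for DivergenceBase; in particular, the default reduction (sum over distribution dimensions, then mean over the batch) is an assumption, not something stated in this documentation.

```python
from typing import Callable, List, Sequence

Matrix = Sequence[Sequence[float]]


def default_reduction(divergence: Matrix) -> float:
    """Assumed default: sum over distribution dimensions, then mean over the batch."""
    return sum(sum(row) for row in divergence) / len(divergence)


class DivergencePipelineSketch:
    """Hypothetical stand-in for the reduction / post-function chain of DivergenceBase."""

    def __init__(self) -> None:
        self.reduction: Callable[[Matrix], float] = default_reduction
        self.post_fcns: List[Callable[[float], float]] = []

    def with_reduction(self, reduction_fcn: Callable[[Matrix], float]):
        self.reduction = reduction_fcn
        return self  # returning self allows chained calls

    def with_post_function(self, post_fcn: Callable[[float], float]):
        self.post_fcns.append(post_fcn)
        return self

    def with_beta(self, beta: float):
        return self.with_post_function(lambda loss: beta * loss)

    def loss(self, divergence: Matrix) -> float:
        value = self.reduction(divergence)
        for fcn in self.post_fcns:  # post functions run after the reduction, in order
            value = fcn(value)
        return value


kl = [[0.5, 1.5], [1.0, 1.0]]  # (batch=2, distribution dims=2)
print(DivergencePipelineSketch().with_beta(4.0).loss(kl))  # 4 * mean(2.0, 2.0) = 8.0
```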
class torchbearer.variational.divergence.SimpleExponentialSimpleExponentialKL(input_key, target_key, state_key=None)
A KL divergence between two SimpleExponential (or similar) distributions.
Note
The distribution objects must have a lograte attribute.
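For two Exponential distributions the KL divergence has the standard closed form KL(Exp(r_p) || Exp(r_q)) = log r_p - log r_q + r_q / r_p - 1, which only needs the log-rates. The plain-Python sketch below (hypothetical names, scalars rather than tensors) evaluates it and checks it against a Monte Carlo estimate of E_p[log p(x) - log q(x)].

```python
import math
import random


def exponential_kl(lograte_p: float, lograte_q: float) -> float:
    """Closed-form KL(Exp(rate_p) || Exp(rate_q)) written in terms of log-rates."""
    return lograte_p - lograte_q + math.exp(lograte_q - lograte_p) - 1.0


# Monte Carlo check: KL = E_p[log p(x) - log q(x)] with x ~ p.
rng = random.Random(0)
lp, lq = math.log(2.0), math.log(0.5)
xs = [rng.expovariate(2.0) for _ in range(50000)]
mc = sum((lp - 2.0 * x) - (lq - 0.5 * x) for x in xs) / len(xs)
print(round(exponential_kl(lp, lq), 4), mc)  # exact value 2*log(2) - 0.75 is about 0.6363
```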
class torchbearer.variational.divergence.SimpleNormalSimpleNormalKL(input_key, target_key, state_key=None)
A KL divergence between two SimpleNormal (or similar) distributions.
Note
The distribution objects must have mu and logvar attributes.
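Per latent dimension, the KL divergence between two Gaussians parameterized by mean and log variance has the standard closed form 0.5 * (logvar_q - logvar_p + (var_p + (mu_p - mu_q)^2) / var_q - 1). A scalar, plain-Python sketch with a hypothetical function name:

```python
import math


def normal_normal_kl(mu_p: float, logvar_p: float,
                     mu_q: float, logvar_q: float) -> float:
    """Closed-form KL(N(mu_p, var_p) || N(mu_q, var_q)) for one latent dimension."""
    var_p, var_q = math.exp(logvar_p), math.exp(logvar_q)
    return 0.5 * (logvar_q - logvar_p + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)


print(normal_normal_kl(0.0, 0.0, 0.0, 0.0))  # 0.0 for identical distributions
print(normal_normal_kl(1.0, 0.0, 0.0, 0.0))  # 0.5: unit variances, means one apart
```

Applied elementwise, this yields one value per (batch, distribution dimension) entry, matching the two-dimensional output described for compute.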
class torchbearer.variational.divergence.SimpleNormalUnitNormalKL(input_key, state_key=None)
A KL divergence between a SimpleNormal (or similar) instance and a fixed unit normal N(0, 1) target.
Note
The distribution object must have mu and logvar attributes.
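Fixing the target to N(0, 1) simplifies the Gaussian KL to the familiar VAE regularizer 0.5 * (mu^2 + exp(logvar) - 1 - logvar). A scalar sketch in plain Python, with a hypothetical function name:

```python
import math


def normal_unit_normal_kl(mu: float, logvar: float) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)) for one latent dimension."""
    return 0.5 * (mu ** 2 + math.exp(logvar) - 1.0 - logvar)


print(normal_unit_normal_kl(0.0, 0.0))  # 0.0: the posterior already matches the prior
print(normal_unit_normal_kl(2.0, 0.0))  # 2.0 = mu ** 2 / 2
```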
class torchbearer.variational.divergence.SimpleWeibullSimpleWeibullKL(input_key, target_key, state_key=None)
A KL divergence between two SimpleWeibull (or similar) distributions.
Note
The distribution objects must have lambda (scale) and k (shape) attributes.
Reference: Christian Bauckhage, "Computing the Kullback-Leibler Divergence between two Generalized Gamma Distributions", CoRR, abs/1401.6853, 2014.
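For reference, the standard closed form of the KL divergence between two Weibull distributions (scale l, shape k) involves the Euler-Mascheroni constant and the gamma function. The plain-Python sketch below evaluates that formula for scalars; the function name is hypothetical, and whether torchbearer arranges the terms in exactly this way is not stated here.

```python
import math

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def weibull_kl(l1: float, k1: float, l2: float, k2: float) -> float:
    """KL(Weibull(l1, k1) || Weibull(l2, k2)) with scale l and shape k."""
    return (math.log(k1 / l1 ** k1) - math.log(k2 / l2 ** k2)
            + (k1 - k2) * (math.log(l1) - EULER_GAMMA / k1)
            + (l1 / l2) ** k2 * math.gamma(k2 / k1 + 1.0) - 1.0)


print(weibull_kl(2.0, 1.5, 2.0, 1.5))  # 0.0: identical distributions
print(weibull_kl(1.0, 1.0, 2.0, 1.0))  # k = 1: Exponential case, log(2) - 0.5
```

With k1 = k2 = 1 both distributions are Exponential, and the expression reduces to log(l2 / l1) + l1 / l2 - 1, which is a quick sanity check on the formula.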
Auto-Encoding
class torchbearer.variational.auto_encoder.AutoEncoderBase(latent_dims)
decode(sample, state=None)
Decode the given batch of latent space samples to images.
Parameters:
- sample – The latent space samples
- state – The trial state
Returns: Decoded images
Datasets
class torchbearer.variational.datasets.SimpleImageFolder(root, loader=None, extensions=None, transform=None, target_transform=None)
class torchbearer.variational.datasets.dSprites(root, download=False, transform=None)
Visualisation
class torchbearer.variational.visualisation.LatentWalker(same_image, row_size)
for_data(data_key)
Parameters: data_key (StateKey) – State key which will contain the data to act on
Returns: self
Return type: LatentWalker
for_space(space_id)
Sets the ID of the latent space to vary when the model outputs [latent_space_0, latent_space_1, …]
Parameters: space_id (int) – ID of the latent space to vary
Returns: self
Return type: LatentWalker
on_train()
Sets the walker to run during training
Returns: self
Return type: LatentWalker
on_val()
Sets the walker to run during validation
Returns: self
Return type: LatentWalker
to_file(file)
Parameters: file (string, pathlib.Path object or file object) – File in which the result is saved
Returns: self
Return type: LatentWalker
to_key(state_key)
Parameters: state_key (StateKey) – State key under which to store the result
Returns: self
Return type: LatentWalker
class torchbearer.variational.visualisation.LinSpaceWalker(lin_start=-1, lin_end=1, lin_steps=8, dims_to_walk=[0], zero_init=False, same_image=False)
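The constructor arguments suggest the following behaviour, sketched here under assumptions: for each dimension in dims_to_walk, lin_steps evenly spaced values from lin_start to lin_end replace that latent coordinate, while the remaining coordinates come from an encoded image (or are zero when zero_init=True). The plain-Python function below is hypothetical and only builds the latent codes; decoding them into an image grid is omitted.

```python
from typing import List, Sequence


def linspace(start: float, end: float, steps: int) -> List[float]:
    """Evenly spaced values from start to end inclusive (like torch.linspace)."""
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


def lin_space_walk(base_code: Sequence[float], lin_start: float = -1.0,
                   lin_end: float = 1.0, lin_steps: int = 8,
                   dims_to_walk: Sequence[int] = (0,),
                   zero_init: bool = False) -> List[List[float]]:
    """One group of latent codes per walked dimension; each group varies one dim."""
    # Assumption: zero_init replaces the encoded code with zeros before walking
    start = [0.0] * len(base_code) if zero_init else list(base_code)
    codes = []
    for dim in dims_to_walk:
        for value in linspace(lin_start, lin_end, lin_steps):
            code = list(start)
            code[dim] = value  # only the walked dimension changes within a group
            codes.append(code)
    return codes


walk = lin_space_walk([0.3, -0.7], lin_steps=3, dims_to_walk=[0, 1])
for code in walk:
    print(code)
```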