import torch
import torchbearer as tb
import torchbearer.callbacks as c
class LatentWalker(c.Callback):
    """Base callback for visualising walks over a model's latent space.

    Subclasses implement :meth:`vis` to build a grid of images; this base class
    handles scheduling (train/validation, once per epoch), pulling the model,
    data and device out of state, and saving or storing the result.
    """

    def __init__(self, same_image, row_size):
        """
        Args:
            same_image (bool): If True, use the same image for all latent dimension walks. Else each
                dimension has a different image
            row_size (int): Number of images displayed in each row of the grid.
        """
        super(LatentWalker, self).__init__()
        self.same_image = same_image
        self.row_size = row_size
        # Optional configuration, populated via the fluent setters below.
        self.data_key = None
        self.store_key = None
        self.file = None
        self.variable_space = 0
        # Populated from state each time the walker runs.
        self.model = None
        self.data = None
        self.dev = None

    def on_train(self):
        """Set the walker to run once per training epoch.

        Returns:
            LatentWalker: self
        """
        self.on_step_training = c.once_per_epoch(self._vis)
        return self

    def on_val(self):
        """Set the walker to run once per validation epoch.

        Returns:
            LatentWalker: self
        """
        self.on_step_validation = c.once_per_epoch(self._vis)
        return self

    def for_space(self, space_id):
        """Set the ID of the latent space to vary when the model outputs
        [latent_space_0, latent_space_1, ...].

        Args:
            space_id (int): ID of the latent space to vary

        Returns:
            LatentWalker: self
        """
        self.variable_space = space_id
        return self

    def for_data(self, data_key):
        """
        Args:
            data_key (:class:`.StateKey`): State key which will contain data to act on

        Returns:
            LatentWalker: self
        """
        self.data_key = data_key
        return self

    def to_key(self, state_key):
        """
        Args:
            state_key (:class:`.StateKey`): State key under which to store result

        Returns:
            LatentWalker: self
        """
        self.store_key = state_key
        return self

    def to_file(self, file):
        """
        Args:
            file (string, pathlib.Path object or file object): File in which result is saved

        Returns:
            LatentWalker: self
        """
        self.file = file
        return self

    def _vis(self, state):
        # Gather everything vis() needs from state, then run it without gradients.
        self.model = state[tb.MODEL]
        self.data = state[tb.X] if self.data_key is None else state[self.data_key]
        self.dev = state[tb.DEVICE]
        with torch.no_grad():
            result = self.vis(state)
        if self.file is not None:
            self._save_walk(result)
        if self.store_key is not None:
            state[self.store_key] = result

    def vis(self, state):
        """Create the tensor of images to be displayed."""
        raise NotImplementedError

    def _save_walk(self, tensor):
        # Local import keeps torchvision an optional dependency unless saving is requested.
        from torchvision.utils import save_image
        save_image(tensor, self.file, self.row_size, normalize=True, pad_value=1)
class ReconstructionViewer(LatentWalker):
    """Latent walker that simply shows a batch of inputs next to their reconstructions."""

    def __init__(self, row_size=8, recon_key=tb.Y_PRED):
        """
        Args:
            row_size (int): Number of images displayed in each row of the grid.
            recon_key (StateKey): :class:`.StateKey` of the reconstructed images
        """
        super(ReconstructionViewer, self).__init__(False, row_size)
        self.recon_key = recon_key

    def vis(self, state):
        # One row of originals stacked above one row of reconstructions.
        originals = self.data[:self.row_size]
        reconstructions = state[self.recon_key][:self.row_size]
        return torch.cat([originals, reconstructions])
class LinSpaceWalker(LatentWalker):
    """Latent space walker that explores each chosen dimension linearly from a start to an end point."""

    def __init__(self, lin_start=-1, lin_end=1, lin_steps=8, dims_to_walk=None, zero_init=False, same_image=False):
        """
        Args:
            lin_start (float): Starting point of linspace
            lin_end (float): End point of linspace
            lin_steps (int): Number of steps to take in linspace
            dims_to_walk (list of int): List of dimensions to walk. Defaults to [0].
            zero_init (bool): If True, dimensions not being walked are 0. Else, they are obtained from encoder
            same_image (bool): If True, use same image for each dimension walked. Else, use different images
        """
        super(LinSpaceWalker, self).__init__(same_image, lin_steps)
        # Fix: the default was previously a mutable list literal ([0]), which is created
        # once and shared by every instance constructed without an explicit argument.
        self.dims_to_walk = [0] if dims_to_walk is None else dims_to_walk
        self.zero_init = zero_init
        self.linspace = torch.linspace(lin_start, lin_end, lin_steps)

    def vis(self, state):
        """Build the grid: one row per walked dimension, row_size steps per row.

        Returns:
            torch.Tensor: Decoded images of shape (num_images, channels, height, width)
        """
        self.linspace = self.linspace.to(self.dev)
        num_images = self.row_size * len(self.dims_to_walk)
        num_spaces = len(self.model.latent_dims)

        # Build base codes of shape (batch, row_size, latent_dim) for each latent space.
        if self.zero_init:
            # All non-walked dimensions fixed at zero.
            sample = [
                torch.zeros(num_images, self.model.latent_dims[i], device=self.dev)
                .unsqueeze(1).repeat(1, self.row_size, 1)
                for i in range(num_spaces)
            ]
        elif self.same_image:
            # Encode a single image and tile it so every row starts from the same code.
            sample = list(self.model.encode(self.data[0], state))
            for i in range(len(sample)):
                sample[i] = sample[i][0].unsqueeze(0).unsqueeze(1).repeat(sample[i].shape[0], self.row_size, 1)
        else:
            # Encode the whole batch; each row starts from a different image's code.
            sample = list(self.model.encode(self.data, state))
            for i in range(len(sample)):
                sample[i] = sample[i].unsqueeze(1).repeat(1, self.row_size, 1)

        # Row i sweeps dimension dims_to_walk[i] across the linspace values.
        for i, dim in enumerate(self.dims_to_walk):
            sample[self.variable_space][i, :, dim] = self.linspace

        # Flatten (batch, row_size, latent_dim) -> (batch * row_size, latent_dim) and
        # keep only the rows that will be displayed.
        for i in range(num_spaces):
            sample[i] = sample[i].view(-1, self.model.latent_dims[i])[:num_images]

        result = self.model.decode(sample).view(num_images, -1, self.data.shape[-2], self.data.shape[-1])
        return result
class RandomWalker(LatentWalker):
    """Latent space walker that shows decoded images for random samples from the latent space."""

    def __init__(self, var=1, num_images=32, uniform=False, row_size=8):
        """
        Args:
            var (float or torch.Tensor): Variance of random sample
            num_images (int): Number of random images to sample
            uniform (bool): If True, sample uniform distribution [-v, v). If False, sample normal distribution with var v
            row_size (int): Number of images displayed in each row of the grid.
        """
        super(RandomWalker, self).__init__(False, row_size)
        self.num_images = num_images
        self.uniform = uniform
        self.var = var

    def vis(self, state):
        """Sample random codes for the variable space (zeros elsewhere) and decode them.

        Returns:
            torch.Tensor: Decoded images of shape (num_images, channels, height, width)
        """
        num_spaces = len(self.model.latent_dims)
        # Non-varied latent spaces are held at zero.
        sample = [
            torch.zeros(self.num_images, self.model.latent_dims[i], device=self.dev)
            for i in range(num_spaces)
        ]
        latent_dim = self.model.latent_dims[self.variable_space]
        if self.uniform:
            # rand is in [0, 1); scale and shift to [-1, 1), then scale by var -> [-var, var).
            sample[self.variable_space] = (torch.rand(self.num_images, latent_dim, device=self.dev) * 2 - 1) * self.var
        else:
            # Fix: the previous code applied the uniform branch's "*2 - 1" shift to randn as
            # well, which moved the mean of the normal samples to -var instead of zero.
            # NOTE(review): multiplying by var scales the standard deviation by var (giving
            # variance var**2) — matches the original scaling, but confirm intent vs docstring.
            sample[self.variable_space] = torch.randn(self.num_images, latent_dim, device=self.dev) * self.var
        result = self.model.decode(sample).view(sample[0].shape[0], -1, self.data.shape[-2], self.data.shape[-1])
        return result
class CodePathWalker(LatentWalker):
    """Latent space walker that walks linearly between two given batches of codes."""

    def __init__(self, num_steps, p1, p2):
        """
        Args:
            num_steps (int): Number of steps to take between points
            p1 (torch.Tensor): Batch of codes
            p2 (torch.Tensor): Batch of codes
        """
        super(CodePathWalker, self).__init__(True, num_steps)
        self.num_steps = num_steps
        self.p1 = p1
        self.p2 = p2

    def vis(self, state):
        # Equal increments so that step 0 lands on p1 and step (num_steps - 1) lands on p2.
        delta = (self.p1 - self.p2) / (self.num_steps - 1)
        codes = torch.zeros(self.p1.shape[0], self.num_steps, self.p1.shape[1]).to(self.dev)
        for step in range(self.num_steps):
            codes[:, step] = self.p1 - delta * step
        # Flatten (batch, num_steps, latent_dim) -> (batch * num_steps, latent_dim) for decoding.
        flat_codes = codes.view(-1, self.p1.shape[1])
        return self.model.decode(flat_codes).view(flat_codes.shape[0], -1, self.data.shape[-2], self.data.shape[-1])
class ImagePathWalker(CodePathWalker):
    """Latent space walker that walks between the codes of two given batches of images."""

    def __init__(self, num_steps, im1, im2):
        """
        Args:
            num_steps (int): Number of steps to take between points
            im1 (torch.Tensor): Batch of images
            im2 (torch.Tensor): Batch of images
        """
        # Codes are not known yet; they are computed from the images at visualisation time.
        super(ImagePathWalker, self).__init__(num_steps, None, None)
        self.im1 = im1
        self.im2 = im2

    def vis(self, state):
        # Encode both image batches on the current device, then walk between the codes.
        self.p1 = self.model.encode(self.im1.to(self.dev), state)
        self.p2 = self.model.encode(self.im2.to(self.dev), state)
        return super(ImagePathWalker, self).vis(state)