Training a GAN
We shall try to implement something more complicated using torchbearer - a Generative Adversarial Network (GAN). This tutorial is a modified version of the GAN from the brilliant collection of GAN implementations PyTorch_GAN by eriklindernoren on GitHub.
Data and Constants
We first define all constants for the example.
import torch

epochs = 200
batch_size = 64
lr = 0.0002
nworkers = 8
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
adversarial_loss = torch.nn.BCELoss()
device = 'cuda'
# Discriminator targets: 1 for real images, 0 for generated images
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)
We then define a number of state keys for convenience. This is optional, but state keys created with state_key automatically avoid conflicts with keys used elsewhere in state.
from torchbearer import state_key

GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')
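To illustrate why unique keys help, here is a minimal, hypothetical registry in plain Python. The names make_key and _registered are illustrative only, and torchbearer's actual renaming scheme may differ: the point is that a second request for a name already in use gets a distinct key instead of silently sharing the first one.

```python
# Hypothetical sketch of a unique-key registry; NOT torchbearer's state_key implementation.
_registered = set()

def make_key(name):
    """Return a key string that is unique among all keys created so far."""
    key, suffix = name, 1
    while key in _registered:
        # Name already taken: try name_1, name_2, ... until one is free
        key = f"{name}_{suffix}"
        suffix += 1
    _registered.add(key)
    return key

a = make_key('g_loss')
b = make_key('g_loss')  # a second component asking for the same name
print(a, b)  # g_loss g_loss_1
```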
We then define the dataset and dataloader - for this example, MNIST.
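import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST images have a single channel, so one mean and one std value
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)
# drop_last keeps every batch at batch_size, matching the valid/fake label tensors
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=nworkers, drop_last=True)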
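With a mean and standard deviation of 0.5, Normalize computes (x - 0.5) / 0.5 for each pixel, which maps the [0, 1] range produced by ToTensor onto [-1, 1]. A plain-Python sketch of that arithmetic (the helper name normalize is hypothetical):

```python
def normalize(x, mean=0.5, std=0.5):
    # Same per-pixel formula that transforms.Normalize applies: (x - mean) / std
    return (x - mean) / std

for pixel in (0.0, 0.5, 1.0):
    print(pixel, '->', normalize(pixel))  # 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
```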
We use the generator and discriminator from PyTorch_GAN and combine them into a model that performs a single forward pass.
import torch
import torch.nn as nn
import torchbearer as tb

# Generator and Discriminator are the classes from PyTorch_GAN (not reproduced here)
class GAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminator = Discriminator()
        self.generator = Generator()

    def forward(self, real_imgs, state):
        # Generator Forward: one N(0, 1) latent vector per real image in the batch
        z = torch.randn(real_imgs.shape[0], latent_dim, device=state[tb.DEVICE])
        state[GEN_IMGS] = self.generator(z)
        state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
        # We don't want to keep discriminator gradients on the generator forward pass
        self.discriminator.zero_grad()

        # Discriminator Forward: detached so the discriminator loss does not reach the generator
        state[DISC_GEN_DET] = self.discriminator(state[GEN_IMGS].detach())
        state[DISC_REAL] = self.discriminator(real_imgs)
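Note that we have to be careful to clear the discriminator's gradients (zero_grad) after the generator forward pass, and to detach the generated images before passing them through the discriminator a second time, so that the discriminator loss does not propagate into the generator.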
Since the loss in this example is more complicated, we forgo the basic loss criterion used in normal torchbearer models and pass a dummy criterion that returns zeros.
import torch

def zero_loss(y_pred, y_true):
    # Dummy criterion: one zero per sample; the real loss is set by a callback
    return torch.zeros(y_true.shape[0], 1)
The actual loss is provided by a callback instead. Since this callback is very simple, we can use a callback decorator on a function (which takes state) to tell torchbearer when it should be called; on_criterion runs it straight after the criterion has been evaluated.
import torchbearer as tb
import torchbearer.callbacks as callbacks

@callbacks.on_criterion
def loss_callback(state):
    fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
    real_loss = adversarial_loss(state[DISC_REAL], valid)
    state[G_LOSS] = adversarial_loss(state[DISC_GEN], valid)
    state[D_LOSS] = (real_loss + fake_loss) / 2
    # This is the loss that backward is called on.
    state[tb.LOSS] = state[G_LOSS] + state[D_LOSS]
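The dummy criterion is harmless because the on_criterion callback runs after the criterion and overwrites the loss in state before backward is called. The toy step below sketches that ordering in plain Python; train_step, toy_zero_loss, toy_loss_callback and the string keys are hypothetical stand-ins, not torchbearer internals.

```python
# Hypothetical toy training step illustrating callback ordering; not torchbearer's code.
def toy_zero_loss(y_pred, y_true):
    return 0.0  # dummy criterion, standing in for zero_loss

def toy_loss_callback(state):
    # Replace the criterion's value with the combined GAN loss
    state['loss'] = state['g_loss'] + state['d_loss']

def train_step(state, criterion, on_criterion_callbacks):
    state['loss'] = criterion(state['y_pred'], state['y_true'])
    for callback in on_criterion_callbacks:
        callback(state)  # on_criterion hooks run after the criterion...
    return state['loss']  # ...so this is the value backward would be called on

state = {'y_pred': None, 'y_true': None, 'g_loss': 0.5, 'd_loss': 0.25}
print(train_step(state, toy_zero_loss, [toy_loss_callback]))  # 0.75
```

Without the callback, the step would return the dummy 0.0, which is why the callback is essential here.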
Note that we sum the separate discriminator and generator losses. This is allowable because the discriminator loss is computed on detached generated images, so its graph does not extend into the generator.
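To make the loss arithmetic concrete, here is a plain-Python sketch of mean binary cross-entropy, -mean(y log p + (1 - y) log(1 - p)), which is what BCELoss computes by default (ignoring the clamping it applies to very negative log values). The discriminator outputs are made-up numbers: a discriminator that outputs 0.5 for every image cannot tell real from fake, and both losses come out as ln 2.

```python
import math

def bce(preds, targets):
    """Mean binary cross-entropy over a batch (no log clamping)."""
    terms = [t * math.log(p) + (1 - t) * math.log(1 - p) for p, t in zip(preds, targets)]
    return -sum(terms) / len(terms)

# Made-up discriminator outputs for a batch of 2 images
disc_gen = [0.5, 0.5]       # D(G(z)), used for the generator loss
disc_gen_det = [0.5, 0.5]   # D(G(z).detach()), same values but detached
disc_real = [0.5, 0.5]      # D(x) on real images
valid, fake = [1.0, 1.0], [0.0, 0.0]

g_loss = bce(disc_gen, valid)  # generator wants its images classified as valid
d_loss = (bce(disc_real, valid) + bce(disc_gen_det, fake)) / 2
print(round(g_loss, 4), round(d_loss, 4), round(g_loss + d_loss, 4))  # 0.6931 0.6931 1.3863
```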
We would like to follow the discriminator and generator losses during training - note that the loss callback adds these to state. We can then create metrics from them by decorating simple metrics that fetch the values from state; g_loss and d_loss follow the same pattern.
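import torchbearer as tb

@tb.metrics.running_mean
@tb.metrics.mean
class g_loss(tb.metrics.Metric):
    def __init__(self):
        super().__init__('g_loss')  # display name of the metric

    def process(self, state):
        # Fetch the generator loss stored by loss_callback
        return state[G_LOSS]


@tb.metrics.running_mean
@tb.metrics.mean
class d_loss(tb.metrics.Metric):
    def __init__(self):
        super().__init__('d_loss')

    def process(self, state):
        # Fetch the discriminator loss stored by loss_callback
        return state[D_LOSS]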
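The two decorators report each loss in two ways: mean averages it over the whole epoch, while running_mean gives a smoothed value over recent batches during training. The plain-Python sketch below shows the difference; epoch_mean, running_means and the window parameter are hypothetical, and torchbearer's own running mean may use a different window size and update schedule.

```python
from collections import deque

def epoch_mean(values):
    # Average over every batch in the epoch
    return sum(values) / len(values)

def running_means(values, window=3):
    # Average over (at most) the last `window` batches, updated after each batch
    recent = deque(maxlen=window)
    out = []
    for v in values:
        recent.append(v)
        out.append(sum(recent) / len(recent))
    return out

batch_losses = [4.0, 2.0, 3.0, 1.0]  # made-up per-batch g_loss values
print(epoch_mean(batch_losses))       # 2.5
print(running_means(batch_losses))    # [4.0, 3.0, 3.0, 2.0]
```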
We can then train the torchbearer model on the GPU in the standard way.
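import torch
import torchbearer as tb

model = GAN()
# Adam with betas=(0.5, 0.999), as in PyTorch_GAN
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

torchbearermodel = tb.Model(model, optim, zero_loss, ['loss', g_loss(), d_loss()])
torchbearermodel.to(device)
# pass_state=True makes torchbearer call model(x, state); saver_callback is defined below
torchbearermodel.fit_generator(dataloader, epochs=epochs, pass_state=True, callbacks=[loss_callback, saver_callback])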
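The saver_callback passed to fit_generator borrows the image saving method from PyTorch_GAN and runs on every training step, saving a grid of generated images every sample_interval batches - again registered with a decorator. In a script, it must be defined before fit_generator is called.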
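import os
import torchbearer as tb
import torchbearer.callbacks as callbacks
from torchvision.utils import save_image

os.makedirs('images', exist_ok=True)  # save_image does not create missing directories

@callbacks.on_step_training
def saver_callback(state):
    # Global batch index across all epochs
    batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH]
    if batches_done % sample_interval == 0:
        # Save the first 25 generated images as a 5x5 grid
        save_image(state[GEN_IMGS].data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)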
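As a worked example of the batches_done arithmetic: the MNIST training split has 60,000 images, so with batch_size = 64 and drop_last=True each epoch has 60000 // 64 = 937 batches. The sketch below (helper names are hypothetical) assumes the epoch and batch counters start at zero and checks which steps trigger a save.

```python
n_batches = 60000 // 64   # 937 batches per epoch with drop_last=True
sample_interval = 400

def batches_done(epoch, batch):
    # Same formula as saver_callback: global index of this training step
    return epoch * n_batches + batch

def saves_image(epoch, batch):
    return batches_done(epoch, batch) % sample_interval == 0

print(batches_done(0, 400), saves_image(0, 400))   # 400 True
print(batches_done(1, 0), saves_image(1, 0))       # 937 False
print(200 * n_batches)                             # 187400 steps in 200 epochs
```

Under the same assumptions, batches_done(183, 929) is 172400, a multiple of 400, so that step saves an image.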
After 172400 iterations we see the following.