Training a GAN

We shall try to implement something more complicated using torchbearer - a Generative Adversarial Network (GAN). This tutorial is a modified version of the GAN from the brilliant collection of GAN implementations PyTorch_GAN by eriklindernoren on GitHub.

Data and Constants

We first define all constants for the example.

import torch

epochs = 200
batch_size = 64
lr = 0.0002
nworkers = 8
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
adversarial_loss = torch.nn.BCELoss()
device = 'cuda'
# Discriminator targets: 1 for real images, 0 for generated ones
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

We then define a number of state keys for convenience using state_key(). This is optional; however, it automatically avoids key conflicts.

from torchbearer import state_key

GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')
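Why does a key helper avoid conflicts? If two pieces of code both wrote to a plain string key such as 'g_loss', the second would silently overwrite the first. The sketch below shows one way a helper like state_key() can prevent this: it hands out a fresh name when the requested one is already taken. The function make_unique_key and its `_1`, `_2` suffix scheme are hypothetical. torchbearer's own state_key() returns StateKey objects, and its internals may differ.

```python
# Hypothetical illustration only: a registry that de-duplicates key names.
# torchbearer's state_key() returns StateKey objects; this sketch just uses strings.
_registered_keys = set()


def make_unique_key(name):
    """Return `name`, or `name_1`, `name_2`, ... if that name is already registered."""
    candidate = name
    suffix = 1
    while candidate in _registered_keys:
        candidate = f"{name}_{suffix}"
        suffix += 1
    _registered_keys.add(candidate)
    return candidate


first = make_unique_key('g_loss')   # 'g_loss'
second = make_unique_key('g_loss')  # 'g_loss_1', so the first key is not clobbered
```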

We then define the dataset and dataloader - for this example, MNIST.

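import os
import torch
from torchvision import datasets, transforms

os.makedirs('./data/mnist', exist_ok=True)
# ToTensor scales pixels to [0, 1]; Normalize maps the single MNIST channel to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         num_workers=nworkers, drop_last=True)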
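The effect of the Normalize step and of drop_last=True can be checked with a few lines of arithmetic. Mapping pixels onto [-1, 1] matches the output range of tanh, which is the final activation of the PyTorch_GAN generator. drop_last=True matters because `valid` and `fake` were created with a fixed batch_size. This is illustrative arithmetic only, assuming the standard 60,000-image MNIST training split.

```python
# Worked arithmetic for the preprocessing and batching choices above (plain Python, no torch).


def normalize(x, mean=0.5, std=0.5):
    """Per-pixel transform applied by Normalize((0.5,), (0.5,)): (x - mean) / std."""
    return (x - mean) / std


# ToTensor() yields pixel values in [0, 1]; normalizing maps them onto [-1, 1].
pixel_range = (normalize(0.0), normalize(0.5), normalize(1.0))  # (-1.0, 0.0, 1.0)

# MNIST's training split has 60,000 images. With drop_last=True the final partial batch
# is discarded, so every batch holds exactly batch_size images, matching the fixed
# `valid` and `fake` target tensors. Because shuffle=True, the skipped images differ per epoch.
num_train_images = 60000
batch_size = 64
full_batches = num_train_images // batch_size   # 937 batches per epoch
dropped_images = num_train_images % batch_size  # 32 images left out each epoch
```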


We use the generator and discriminator from PyTorch_GAN and combine them into a model that performs a single forward pass.

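import torch
import torch.nn as nn
import torchbearer as tb


# Generator and Discriminator are the models from PyTorch_GAN (not defined here)
class GAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminator = Discriminator()
        self.generator = Generator()

    def forward(self, real_imgs, state):
        # Generator forward: sample latent noise z ~ N(0, 1) on the trial's device
        z = torch.randn(real_imgs.shape[0], latent_dim, device=state[tb.DEVICE])
        state[GEN_IMGS] = self.generator(z)
        state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
        # Clear any gradients already accumulated on the discriminator's parameters
        self.discriminator.zero_grad()

        # Discriminator forward: detach() keeps the discriminator loss out of the generator
        state[DISC_GEN_DET] = self.discriminator(state[GEN_IMGS].detach())
        state[DISC_REAL] = self.discriminator(real_imgs)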

Note that we have to be careful to remove the gradient information from the discriminator after doing the generator forward pass.


Since our loss computation in this example is complicated, we shall forgo the basic loss criterion used in normal torchbearer trials. Instead, we use a callback to provide the loss: in this case, the add_to_loss() callback decorator. This decorates a function that returns a loss and automatically adds it to the main loss in training and validation.

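from torchbearer.callbacks import add_to_loss

@add_to_loss
def loss_callback(state):
    # Discriminator: generated (detached) images should be judged fake, real images valid
    fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
    real_loss = adversarial_loss(state[DISC_REAL], valid)
    # Generator: generated images should fool the discriminator into judging them valid
    state[G_LOSS] = adversarial_loss(state[DISC_GEN], valid)
    state[D_LOSS] = (real_loss + fake_loss) / 2
    return state[G_LOSS] + state[D_LOSS]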

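Note that we have summed the separate discriminator and generator losses. This is allowable because the generated images were detached before computing DISC_GEN_DET, so the discriminator loss does not back-propagate into the generator.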
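To make the loss arithmetic concrete, here is the same computation in plain Python on a batch of two. It uses made-up discriminator outputs and a hand-written mean binary cross-entropy that mirrors torch.nn.BCELoss with its default mean reduction. The helper names bce and safe_log are illustrative. Since PyTorch_GAN's discriminator has no dropout, DISC_GEN and DISC_GEN_DET hold the same values. detach() changes only where gradients flow, not the numbers.

```python
import math


def safe_log(v):
    # torch.nn.BCELoss clamps its log terms at -100 so that log(0) never produces -inf
    return max(math.log(v), -100.0) if v > 0 else -100.0


def bce(preds, targets):
    """Mean binary cross-entropy, as computed by torch.nn.BCELoss() with default reduction."""
    terms = [-(y * safe_log(p) + (1 - y) * safe_log(1 - p)) for p, y in zip(preds, targets)]
    return sum(terms) / len(terms)


# Illustrative discriminator outputs (probability that an image is real) for a batch of two.
disc_real = [0.9, 0.8]  # D on real images
disc_gen = [0.2, 0.3]   # D on generated images (DISC_GEN and DISC_GEN_DET alike)
valid = [1.0, 1.0]
fake = [0.0, 0.0]

fake_loss = bce(disc_gen, fake)    # low when D calls generated images fake
real_loss = bce(disc_real, valid)  # low when D calls real images real
d_loss = (real_loss + fake_loss) / 2  # about 0.2271
g_loss = bce(disc_gen, valid)      # about 1.4067: high because D is not yet fooled
total_loss = g_loss + d_loss       # the value loss_callback returns
```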


We would like to follow the discriminator and generator losses during training. Note that we added these to state in the loss callback. We can then create metrics from these by decorating simple state-fetcher metrics.

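@tb.metrics.running_mean
@tb.metrics.mean
class g_loss(tb.metrics.Metric):
    def __init__(self):
        super().__init__('g_loss')

    def process(self, state):
        return state[G_LOSS]

# d_loss is defined in the same way, returning state[D_LOSS] under the name 'd_loss'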
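Decorating a state-fetcher metric wraps its per-batch value in an aggregate, such as a mean or a running mean. This yields a smoothed figure in place of a noisy per-batch one. As a rough illustration of what a running mean does, here is a hypothetical windowed version in plain Python. The class name WindowedRunningMean and the window size are assumptions. torchbearer's own running-mean decorator has its own defaults and update schedule.

```python
from collections import deque


class WindowedRunningMean:
    """Hypothetical stand-in for a running-mean wrapper:
    reports the mean of the most recent `window` per-batch values."""

    def __init__(self, window=3):
        self.values = deque(maxlen=window)  # oldest value drops out automatically

    def process(self, value):
        self.values.append(value)
        return sum(self.values) / len(self.values)


smoother = WindowedRunningMean(window=3)
outputs = [smoother.process(v) for v in [1.0, 2.0, 3.0, 4.0]]
# outputs == [1.0, 1.5, 2.0, 3.0]
```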


We can then train the torchbearer trial on the GPU in the standard way. Note that when torchbearer is passed a None criterion, it automatically sets the base loss to 0.

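model = GAN()
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
torchbearertrial = tb.Trial(model, optim, criterion=None, metrics=['loss', g_loss(), d_loss()],
                            callbacks=[loss_callback, saver_callback], pass_state=True)
torchbearertrial.with_train_generator(dataloader).to(device).run(epochs=epochs)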
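With criterion=None, the loss back-propagated at each step is a zero base loss plus whatever the loss callback returns. The plain-Python sketch below models that flow, using a dictionary as a stand-in for torchbearer's state. assemble_loss and toy_loss_callback are hypothetical names with made-up loss values. The real Trial performs this inside its own training loop.

```python
# Hypothetical sketch of how the training loss is assembled when criterion=None
# and an add_to_loss-style callback supplies the loss. Not torchbearer's source.


def assemble_loss(state, criterion, loss_callbacks):
    # With no criterion, the base loss starts at zero
    state['loss'] = 0.0 if criterion is None else criterion(state)
    for callback in loss_callbacks:
        # Each add_to_loss-style callback returns a value that is added to the main loss
        state['loss'] = state['loss'] + callback(state)
    return state['loss']


def toy_loss_callback(state):
    # Stand-in for loss_callback, with made-up generator and discriminator losses
    state['g_loss'], state['d_loss'] = 1.5, 0.25
    return state['g_loss'] + state['d_loss']


state = {}
loss = assemble_loss(state, criterion=None, loss_callbacks=[toy_loss_callback])
# loss == 1.75; state also carries 'g_loss' and 'd_loss' for the metrics to read
```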


We borrow the image saving method from PyTorch_GAN and put it in a callback that runs on_step_training(). We generate from the same latent inputs each time, so that the saved images show how the output for a fixed input evolves during training.

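import os
from torchbearer.callbacks import on_step_training
from torchvision.utils import save_image

os.makedirs('images', exist_ok=True)
batch = torch.randn(25, latent_dim, device=device)
@on_step_training
def saver_callback(state):
    batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH]
    if batches_done % sample_interval == 0:
        samples = state[tb.MODEL].generator(batch)
        save_image(samples, 'images/%d.png' % batches_done, nrow=5, normalize=True)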
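batches_done is a global step count rather than a per-epoch one. As a result, images are not saved at the same batch index in every epoch. The arithmetic below traces the schedule under stated assumptions: 937 batches per epoch (60,000 MNIST images, batch_size=64, drop_last=True), sample_interval=400, and epoch and batch counters that both start at 0.

```python
# Worked example of the save schedule used by saver_callback (plain Python, no torch).
batches_per_epoch = 937
sample_interval = 400
epochs = 200

saves = []
for epoch in range(2):  # first two epochs only, to keep the list short
    for batch_idx in range(batches_per_epoch):
        batches_done = epoch * batches_per_epoch + batch_idx
        if batches_done % sample_interval == 0:
            saves.append((epoch, batch_idx, batches_done))

# saves == [(0, 0, 0), (0, 400, 400), (0, 800, 800), (1, 263, 1200), (1, 663, 1600)]

# Over the full run, batches_done goes from 0 to epochs * batches_per_epoch - 1,
# so the number of images written is:
total_images = (epochs * batches_per_epoch - 1) // sample_interval + 1  # 469
```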

Here is a GIF created from the saved images.

GAN generated samples

Source Code

The source code for the example is given below: