Using DistributedDataParallel with Torchbearer on CPU

This note will quickly cover how we can use torchbearer to train over multiple nodes. We shall do this by training a simple classification model, and for a massive amount of overkill we will be doing this on MNIST. Most of the code for this example is based on the Distributed Data Parallel (DDP) tutorial and the ImageNet example from the PyTorch docs. We recommend you read at least the DDP tutorial before continuing with this note.
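
The snippets below omit their imports; roughly, they assume something like the following sketch (the full, definitive list lives in the source file linked at the end of this note):

import os
import platform

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import datasets, transforms

import torchbearer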

Setup, Cleanup and Model

We keep the setup, cleanup and model similar to those from the DDP tutorial. All that changes is that we take the rank, world size and master address from command-line arguments and adapt the model to MNIST. Note that we stick with the Gloo backend, since this part of the note runs purely on the CPU.
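
The snippets refer to a global args namespace holding these values. A minimal sketch of how it might be built with argparse is shown here; the flag names are assumptions chosen to match the command in the Running section (with --host stored as args.master), so check the source file for the definitive version.

import argparse

# Hypothetical argument parsing for this example; the real source file may differ.
parser = argparse.ArgumentParser(description='Torchbearer DDP MNIST example')
parser.add_argument('--host', dest='master', help='address of the master (rank 0) node')
parser.add_argument('--rank', type=int, help='rank of this process')
parser.add_argument('--world-size', type=int, help='total number of processes')
args = parser.parse_args()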

def setup():
    os.environ['MASTER_ADDR'] = args.master
    os.environ['MASTER_PORT'] = '29500'

    # initialize the process group
    dist.init_process_group("gloo", rank=args.rank, world_size=args.world_size)

    # Explicitly setting seed makes sure that models created in two processes
    # start from same random weights and biases. Alternatively, sync models
    # on start with the callback below.
    #torch.manual_seed(42)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(784, 100)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(100, 10)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

Sync Methods

Since we are working across multiple machines, we need a way to synchronise the model itself and its gradients. To do this we use methods similar to those in the distributed applications tutorial from PyTorch.

def sync_model(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= size


def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size

Since we require the gradients to be synced every step, we implement both of these methods as Torchbearer callbacks. We sync the model itself on init and sync the gradients after the backward call on every step.

@torchbearer.callbacks.on_init
def sync(state):
    sync_model(state[torchbearer.MODEL])


@torchbearer.callbacks.on_backward
def grad(state):
    average_gradients(state[torchbearer.MODEL])

Worker Function

Now we need to define the main worker function that each process will run. It needs to set up the environment, run the training process and clean up the environment after we finish. Apart from the calls to setup and cleanup, this function is exactly the same as any other Torchbearer training function.

def worker():
    setup()
    print("Rank and node: {}-{}".format(args.rank, platform.node()))

    model = ToyModel().to('cpu')
    ddp_model = DDP(model)

    kwargs = {}

    ds = datasets.MNIST('./data/mnist/', train=True, download=True,
         transform=transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))
          ]))

    train_sampler = torch.utils.data.distributed.DistributedSampler(ds)
    train_loader = torch.utils.data.DataLoader(ds,
        batch_size=128, sampler=train_sampler, **kwargs)

    test_ds = datasets.MNIST('./data/mnist', train=False,
              transform=transforms.Compose([
                 transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))
                 ]))
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_ds)
    test_loader = torch.utils.data.DataLoader(test_ds,
        batch_size=128, sampler=test_sampler,  **kwargs)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    trial = torchbearer.Trial(ddp_model, optimizer, loss_fn, metrics=['loss', 'acc'],
        callbacks=[sync, grad, flatten])
    trial.with_train_generator(train_loader)
    trial.run(10, verbose=2)

    print("Model hash: {}".format(hash(model)))
    print('First parameter: {}'.format(next(model.parameters())))

    cleanup()

You might have noticed an extra flatten callback in the Trial. Its only purpose is to flatten each 1x28x28 MNIST image into a 784-dimensional vector so that it matches the input of the first linear layer.

@torchbearer.callbacks.on_sample
def flatten(state):
    state[torchbearer.X] = state[torchbearer.X].view(state[torchbearer.X].shape[0], -1)

Running

All we need to do now is write a __main__ block to run the worker function.

if __name__ == "__main__":
    worker()
    print('done')

We can then ssh into each node on which we want to run the training and run the following command, replacing i with the rank of each process.

python distributed_data_parallel.py --world-size 2 --rank i --host (host address)
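
For example, with a world size of two you would run something like the following, once on each node (the address here is just a placeholder for your master node):

python distributed_data_parallel.py --world-size 2 --rank 0 --host 192.168.0.10
python distributed_data_parallel.py --world-size 2 --rank 1 --host 192.168.0.10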

Running on machines with GPUs

Coming soon.

Source Code

The source code for this example is given below: