Quickstart Guide
This guide will give a quick intro to training PyTorch models with torchbearer. We’ll start by loading in some data and defining a model, then we’ll train it for a few epochs and see how well it does.
Defining the Model
Let’s get started with torchbearer. Here’s some data from CIFAR-10 and a simple three-layer strided CNN:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

BATCH_SIZE = 128

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

trainset = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True,
                                        transform=transforms.Compose([transforms.ToTensor(), normalize]))
traingen = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)

testset = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True,
                                       transform=transforms.Compose([transforms.ToTensor(), normalize]))
testgen = torch.utils.data.DataLoader(testset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=False, num_workers=10)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # Three strided 3x3 convolutions, each followed by batch norm and ReLU
        self.convs = nn.Sequential(
            nn.Conv2d(3, 16, stride=2, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, stride=2, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, stride=2, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # 64 channels x 3 x 3 spatial positions = 576 features
        self.classifier = nn.Linear(576, 10)

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, 576)  # flatten the feature map for the linear layer
        return self.classifier(x)


model = SimpleModel()
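The 576 passed to nn.Linear is the flattened size of the final feature map. As a quick check, assuming 32x32 CIFAR-10 inputs and the nn.Conv2d defaults of no padding and no dilation, the usual output-size formula floor((n + 2p - k) / s) + 1 can be applied once per layer. The helper name conv_out is hypothetical and exists only for this illustration:

```python
def conv_out(size, kernel_size=3, stride=2, padding=0):
    """Spatial size after one convolution (dilation of 1 assumed)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 32  # CIFAR-10 images are 32x32 pixels
sizes = []
for _ in range(3):  # the model stacks three identical strided convolutions
    size = conv_out(size)
    sizes.append(size)

# The last convolution has 64 output channels
features = 64 * size * size
print(sizes, features)
```

The spatial size shrinks from 32 to 15, 7 and finally 3, so the classifier sees 64 * 3 * 3 = 576 features. Changing the input resolution, kernel size, stride or padding would change this number, and the two occurrences of 576 in the model would need updating to match.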
Typically we would need to write a training loop with a series of calls to backward, step, etc. Instead, with torchbearer, we can define our optimiser and some metrics (just ‘acc’ and ‘loss’ for now) and let it do the work.
Training on CIFAR-10
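For comparison, a hand-written loop in plain PyTorch might look roughly like the sketch below. To keep it quick to run on a CPU it uses a small random dataset and a single linear layer as stand-ins; these, together with the names toy_model, toy_loader and train_one_epoch, are assumptions made for illustration and are not part of this guide's example.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in data: 64 random "images" with 10 possible labels (illustration only)
inputs = torch.randn(64, 3, 32, 32)
targets = torch.randint(0, 10, (64,))
toy_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16, shuffle=True)

# Stand-in model: flatten each image and apply one linear layer
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_optimizer = optim.Adam(toy_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


def train_one_epoch(model, loader, optimizer, criterion):
    """Run one pass over the data and return (mean loss, accuracy)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for x, y in loader:
        optimizer.zero_grad()          # clear gradients from the previous step
        y_pred = model(x)              # forward pass
        loss = criterion(y_pred, y)
        loss.backward()                # backward pass
        optimizer.step()               # parameter update
        total_loss += loss.item() * x.size(0)
        correct += (y_pred.argmax(dim=1) == y).sum().item()
        seen += x.size(0)
    return total_loss / seen, correct / seen


history = [train_one_epoch(toy_model, toy_loader, toy_optimizer, criterion) for _ in range(3)]
for epoch, (epoch_loss, epoch_acc) in enumerate(history):
    print(f"epoch {epoch}: loss={epoch_loss:.3f}, acc={epoch_acc:.3f}")
```

Even this stripped-down version has to handle gradient resetting and metric bookkeeping by hand, and it still lacks validation, progress bars and device placement.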
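import torch.optim as optim
from torchbearer import Model

# Optimise only the parameters that require gradients
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss = nn.CrossEntropyLoss()

# Wrap model, optimiser, loss and metrics; .to('cuda') requires a CUDA-capable GPU
torchbearer_model = Model(model, optimizer, loss, metrics=['acc', 'loss']).to('cuda')
torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen)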
Running the above produces the following output:
Files already downloaded and verified
Files already downloaded and verified
0/10(t): 100%|██████████| 391/391 [00:01<00:00, 211.19it/s, running_acc=0.549, running_loss=1.25, acc=0.469, acc_std=0.499, loss=1.48, loss_std=0.238]
0/10(v): 100%|██████████| 79/79 [00:00<00:00, 265.14it/s, val_acc=0.556, val_acc_std=0.497, val_loss=1.25, val_loss_std=0.0785]
1/10(t): 100%|██████████| 391/391 [00:01<00:00, 209.80it/s, running_acc=0.61, running_loss=1.09, acc=0.593, acc_std=0.491, loss=1.14, loss_std=0.0968]
1/10(v): 100%|██████████| 79/79 [00:00<00:00, 227.97it/s, val_acc=0.593, val_acc_std=0.491, val_loss=1.14, val_loss_std=0.0865]
2/10(t): 100%|██████████| 391/391 [00:01<00:00, 220.70it/s, running_acc=0.656, running_loss=0.972, acc=0.645, acc_std=0.478, loss=1.01, loss_std=0.0915]
2/10(v): 100%|██████████| 79/79 [00:00<00:00, 218.91it/s, val_acc=0.631, val_acc_std=0.482, val_loss=1.04, val_loss_std=0.0951]
3/10(t): 100%|██████████| 391/391 [00:01<00:00, 208.67it/s, running_acc=0.682, running_loss=0.906, acc=0.675, acc_std=0.468, loss=0.922, loss_std=0.0895]
3/10(v): 100%|██████████| 79/79 [00:00<00:00, 86.95it/s, val_acc=0.657, val_acc_std=0.475, val_loss=0.97, val_loss_std=0.0925]
4/10(t): 100%|██████████| 391/391 [00:01<00:00, 211.22it/s, running_acc=0.693, running_loss=0.866, acc=0.699, acc_std=0.459, loss=0.86, loss_std=0.092]
4/10(v): 100%|██████████| 79/79 [00:00<00:00, 249.74it/s, val_acc=0.662, val_acc_std=0.473, val_loss=0.957, val_loss_std=0.093]
5/10(t): 100%|██████████| 391/391 [00:01<00:00, 205.12it/s, running_acc=0.71, running_loss=0.826, acc=0.713, acc_std=0.452, loss=0.818, loss_std=0.0904]
5/10(v): 100%|██████████| 79/79 [00:00<00:00, 230.12it/s, val_acc=0.661, val_acc_std=0.473, val_loss=0.962, val_loss_std=0.0966]
6/10(t): 100%|██████████| 391/391 [00:01<00:00, 210.87it/s, running_acc=0.714, running_loss=0.81, acc=0.727, acc_std=0.445, loss=0.779, loss_std=0.0904]
6/10(v): 100%|██████████| 79/79 [00:00<00:00, 241.95it/s, val_acc=0.667, val_acc_std=0.471, val_loss=0.952, val_loss_std=0.11]
7/10(t): 100%|██████████| 391/391 [00:01<00:00, 209.94it/s, running_acc=0.727, running_loss=0.791, acc=0.74, acc_std=0.439, loss=0.747, loss_std=0.0911]
7/10(v): 100%|██████████| 79/79 [00:00<00:00, 223.23it/s, val_acc=0.673, val_acc_std=0.469, val_loss=0.938, val_loss_std=0.107]
8/10(t): 100%|██████████| 391/391 [00:01<00:00, 203.16it/s, running_acc=0.747, running_loss=0.736, acc=0.752, acc_std=0.432, loss=0.716, loss_std=0.0899]
8/10(v): 100%|██████████| 79/79 [00:00<00:00, 221.55it/s, val_acc=0.679, val_acc_std=0.467, val_loss=0.923, val_loss_std=0.113]
9/10(t): 100%|██████████| 391/391 [00:01<00:00, 213.23it/s, running_acc=0.756, running_loss=0.701, acc=0.759, acc_std=0.428, loss=0.695, loss_std=0.0915]
9/10(v): 100%|██████████| 79/79 [00:00<00:00, 245.33it/s, val_acc=0.676, val_acc_std=0.468, val_loss=0.951, val_loss_std=0.111]
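Two details of the progress bars can be checked with a little arithmetic. The counts 391 and 79 are the number of batches per epoch: CIFAR-10 has 50,000 training and 10,000 test images, and a DataLoader with drop_last left at its default keeps the final partial batch. The acc_std values also look consistent with the standard deviation of a per-sample 0/1 correctness indicator, sqrt(p * (1 - p)). That second point is an observation about the numbers above, not a statement about how torchbearer computes them, and the helper name indicator_std is hypothetical.

```python
import math

BATCH_SIZE = 128

# Batches per epoch, counting the final partial batch
train_batches = math.ceil(50000 / BATCH_SIZE)
test_batches = math.ceil(10000 / BATCH_SIZE)
print(train_batches, test_batches)


def indicator_std(p):
    """Standard deviation of a 0/1 variable whose mean is p."""
    return math.sqrt(p * (1.0 - p))


# (accuracy, reported std) pairs copied from the output above
reported = [(0.469, 0.499), (0.556, 0.497), (0.759, 0.428), (0.676, 0.468)]
for acc, acc_std in reported:
    print(f"acc={acc}: predicted std={indicator_std(acc):.3f}, reported std={acc_std}")
```

Read this way, acc_std carries no information beyond acc itself, whereas loss_std does describe how much the loss varies during the epoch.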