import copy
import torchvision.utils as utils
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import torchbearer
from torchbearer.callbacks import Callback
import os
import torch
class TensorBoard(Callback):
    """The TensorBoard callback is used to write metric graphs to tensorboard. Requires the TensorboardX library for
    python.
    """

    def __init__(self, log_dir='./logs',
                 write_graph=True,
                 write_batch_metrics=False,
                 batch_step_size=10,
                 write_epoch_metrics=True,
                 comment='torchbearer'):
        """TensorBoard callback which writes metrics to the given log directory.

        :param log_dir: The tensorboard log path for output
        :type log_dir: str
        :param write_graph: If True, the model graph will be written using the TensorboardX library
        :type write_graph: bool
        :param write_batch_metrics: If True, batch metrics will be written
        :type write_batch_metrics: bool
        :param batch_step_size: The step size to use when writing batch metrics, make this larger to reduce latency
        :type batch_step_size: int
        :param write_epoch_metrics: If True, metrics from the end of the epoch will be written
        :type write_epoch_metrics: bool
        :param comment: Descriptive comment to append to path
        :type comment: str
        """
        super(TensorBoard, self).__init__()

        self.log_dir = log_dir
        self.write_graph = write_graph
        self.write_batch_metrics = write_batch_metrics
        self.batch_step_size = batch_step_size
        self.write_epoch_metrics = write_epoch_metrics
        self.comment = comment

        if self.write_graph:
            def handle_graph(state):
                # Trace the graph once, using a CPU copy of the model and a dummy
                # input shaped like the current batch, then swap this handler for a
                # no-op so the (expensive) trace only happens on the first sample.
                dummy = torch.rand(state[torchbearer.X].size(), requires_grad=False)
                model = copy.deepcopy(state[torchbearer.MODEL]).to('cpu')
                self._writer.add_graph(model, (dummy, ))
                self._handle_graph = lambda _: ...
            self._handle_graph = handle_graph
        else:
            self._handle_graph = lambda _: ...

        self._writer = None
        self._batch_writer = None
        # Full run directory (<log_dir>/<ModelName>_<comment>); filled in on_start.
        # Kept separate from self.log_dir so that reusing the callback for a second
        # run does not repeatedly nest the model name onto the path.
        self._log_dir = log_dir

    def on_start(self, state):
        # Derive the run directory from the user-supplied log_dir without mutating it.
        self._log_dir = os.path.join(self.log_dir, state[torchbearer.MODEL].__class__.__name__ + '_' + self.comment)
        self._writer = SummaryWriter(log_dir=self._log_dir)

    def on_start_epoch(self, state):
        # Batch metrics get a fresh per-epoch writer so each epoch is a separate run.
        if self.write_batch_metrics:
            log_dir = os.path.join(self._log_dir, 'epoch-' + str(state[torchbearer.EPOCH]))
            self._batch_writer = SummaryWriter(log_dir=log_dir)

    def on_sample(self, state):
        self._handle_graph(state)

    def on_step_training(self, state):
        # Write training batch metrics every batch_step_size steps to limit latency.
        if self.write_batch_metrics and state[torchbearer.BATCH] % self.batch_step_size == 0:
            for metric in state[torchbearer.METRICS]:
                self._batch_writer.add_scalar('batch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.BATCH])

    def on_step_validation(self, state):
        # Write validation batch metrics every batch_step_size steps.
        if self.write_batch_metrics and state[torchbearer.BATCH] % self.batch_step_size == 0:
            for metric in state[torchbearer.METRICS]:
                self._batch_writer.add_scalar('batch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.BATCH])

    def on_end_epoch(self, state):
        if self.write_batch_metrics:
            self._batch_writer.close()
        if self.write_epoch_metrics:
            for metric in state[torchbearer.METRICS]:
                self._writer.add_scalar('epoch/' + metric, state[torchbearer.METRICS][metric], state[torchbearer.EPOCH])

    def on_end(self, state):
        self._writer.close()
class TensorBoardImages(Callback):
    """The TensorBoardImages callback will write a selection of images from the validation pass to tensorboard using the
    TensorboardX library and torchvision.utils.make_grid
    """

    def __init__(self, log_dir='./logs',
                 comment='torchbearer',
                 name='Image',
                 key=torchbearer.Y_PRED,
                 write_each_epoch=True,
                 num_images=16,
                 nrow=8,
                 padding=2,
                 normalize=False,
                 range=None,
                 scale_each=False,
                 pad_value=0):
        """Create TensorBoardImages callback which writes images from the given key to the given path. Full name of
        image sub directory will be model name + _ + comment.

        :param log_dir: The tensorboard log path for output
        :type log_dir: str
        :param comment: Descriptive comment to append to path
        :type comment: str
        :param name: The name of the image
        :type name: str
        :param key: The key in state containing image data (tensor of size [c, w, h] or [b, c, w, h])
        :type key: str
        :param write_each_epoch: If True, write data on every epoch, else write only for the first epoch.
        :type write_each_epoch: bool
        :param num_images: The number of images to write
        :type num_images: int
        :param nrow: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
        :param padding: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
        :param normalize: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
        :param range: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
        :param scale_each: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
        :param pad_value: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
        """
        super(TensorBoardImages, self).__init__()  # was missing; keep consistent with TensorBoard

        self.log_dir = log_dir
        self.comment = comment
        self.name = name
        self.key = key
        self.write_each_epoch = write_each_epoch
        self.num_images = num_images
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.range = range  # NOTE: shadows builtin, but the name is part of the public interface (mirrors make_grid)
        self.scale_each = scale_each
        self.pad_value = pad_value

        self._writer = None
        self._data = None   # accumulated image batch, built up across validation steps
        self.done = False   # True once the grid for the current epoch has been written

    def on_start(self, state):
        log_dir = os.path.join(self.log_dir, state[torchbearer.MODEL].__class__.__name__ + '_' + self.comment)
        self._writer = SummaryWriter(log_dir=log_dir)

    def on_step_validation(self, state):
        # Accumulate up to num_images images across validation batches; once enough
        # have been gathered, write a single grid for this epoch and stop collecting.
        if not self.done:
            data = state[self.key].clone()

            # Promote [b, w, h] data to [b, 1, w, h] so make_grid sees a channel dim.
            if len(data.size()) == 3:
                data = data.unsqueeze(1)

            if self._data is None:
                remaining = min(self.num_images, data.size(0))
                self._data = data[:remaining].to('cpu')
            else:
                remaining = min(self.num_images - self._data.size(0), data.size(0))
                self._data = torch.cat((self._data, data[:remaining].to('cpu')), dim=0)

            if self._data.size(0) >= self.num_images:
                image = utils.make_grid(
                    self._data,
                    nrow=self.nrow,
                    padding=self.padding,
                    normalize=self.normalize,
                    range=self.range,
                    scale_each=self.scale_each,
                    pad_value=self.pad_value
                )
                self._writer.add_image(self.name, image, state[torchbearer.EPOCH])
                self.done = True
                self._data = None

    def on_end_epoch(self, state):
        # Re-arm the collector so a new grid is written next epoch (if requested).
        if self.write_each_epoch:
            self.done = False

    def on_end(self, state):
        self._writer.close()
class TensorBoardProjector(Callback):
    """The TensorBoardProjector callback is used to write images from the validation pass to Tensorboard using the
    TensorboardX library.
    """

    def __init__(self, log_dir='./logs',
                 comment='torchbearer',
                 num_images=100,
                 avg_pool_size=1,
                 avg_data_channels=True,
                 write_data=True,
                 write_features=True,
                 features_key=torchbearer.Y_PRED):
        """Construct a TensorBoardProjector callback which writes images to the given directory and, if required,
        associated features.

        :param log_dir: The tensorboard log path for output
        :type log_dir: str
        :param comment: Descriptive comment to append to path
        :type comment: str
        :param num_images: The number of images to write
        :type num_images: int
        :param avg_pool_size: Size of the average pool to perform on the image. This is recommended to reduce the
            overall image sizes and improve latency
        :type avg_pool_size: int
        :param avg_data_channels: If True, the image data will be averaged in the channel dimension
        :type avg_data_channels: bool
        :param write_data: If True, the raw data will be written as an embedding
        :type write_data: bool
        :param write_features: If True, the image features will be written as an embedding
        :type write_features: bool
        :param features_key: The key in state to use for the embedding. Typically model output but can be used to show
            features from any layer of the model.
        :type features_key: str
        """
        super(TensorBoardProjector, self).__init__()  # was missing; keep consistent with TensorBoard

        self.log_dir = log_dir
        self.comment = comment
        self.num_images = num_images
        self.avg_pool_size = avg_pool_size
        self.avg_data_channels = avg_data_channels
        self.write_data = write_data
        self.write_features = write_features
        self.features_key = features_key

        self._writer = None
        self.done = False
        # Accumulators built up across validation batches. Initialised here so the
        # attributes always exist, rather than only after the first batch-0 step.
        self._images = None
        self._labels = None
        self._data = None
        self._features = None

    def on_start(self, state):
        log_dir = os.path.join(self.log_dir, state[torchbearer.MODEL].__class__.__name__ + '_' + self.comment)
        self._writer = SummaryWriter(log_dir=log_dir)

    def on_step_validation(self, state):
        # Accumulate up to num_images thumbnails, labels and (optionally) raw data /
        # feature vectors, then write the embedding(s) once and stop for this epoch.
        if not self.done:
            x = state[torchbearer.X].data.clone()

            # Promote [b, w, h] input to [b, 1, w, h] before pooling.
            if len(x.size()) == 3:
                x = x.unsqueeze(1)

            # Downsample thumbnails to keep the projector payload small.
            x = F.avg_pool2d(x, self.avg_pool_size).data

            data = None
            # Raw-data embedding is only written for the first epoch.
            if state[torchbearer.EPOCH] == 0 and self.write_data:
                if self.avg_data_channels:
                    data = torch.mean(x, 1)
                else:
                    data = x
                data = data.view(data.size(0), -1)

            feature = None
            if self.write_features:
                feature = state[self.features_key].data.clone()
                feature = feature.view(feature.size(0), -1)

            label = state[torchbearer.Y_TRUE].data.clone()

            if state[torchbearer.BATCH] == 0:
                remaining = min(self.num_images, label.size(0))

                self._images = x[:remaining].to('cpu')
                self._labels = label[:remaining].to('cpu')

                if data is not None:
                    self._data = data[:remaining].to('cpu')

                if feature is not None:
                    self._features = feature[:remaining].to('cpu')
            else:
                remaining = min(self.num_images - self._labels.size(0), label.size(0))

                self._images = torch.cat((self._images, x[:remaining].to('cpu')), dim=0)
                self._labels = torch.cat((self._labels, label[:remaining].to('cpu')), dim=0)

                if data is not None:
                    self._data = torch.cat((self._data, data[:remaining].to('cpu')), dim=0)

                if feature is not None:
                    self._features = torch.cat((self._features, feature[:remaining].to('cpu')), dim=0)

            if self._labels.size(0) >= self.num_images:
                if state[torchbearer.EPOCH] == 0 and self.write_data:
                    self._writer.add_embedding(self._data, metadata=self._labels, label_img=self._images, tag='data', global_step=-1)
                if self.write_features:
                    self._writer.add_embedding(self._features, metadata=self._labels, label_img=self._images, tag='features', global_step=state[torchbearer.EPOCH])
                self.done = True

    def on_end_epoch(self, state):
        # Only the features embedding is written per-epoch; if it is disabled there
        # is nothing new to collect, so leave done=True after the data embedding.
        if self.write_features:
            self.done = False

    def on_end(self, state):
        self._writer.close()