Source code for torchbearer.callbacks.imaging.imaging

import torchbearer
from torchbearer import Callback
from torchbearer.callbacks.decorators import once_per_epoch

import torch


def _to_file(filename):
    """Build an output handler which saves each image it receives to ``filename`` using PIL.

    PIL is imported when the handler is built, so a missing install shows up at that point rather than on the
    first image.

    Args:
        filename (str): Path passed to ``Image.save`` on every call (the same path each time)

    Returns:
        callable: A function of (image, state); the state argument is ignored
    """
    from PIL import Image

    def handler(image, _):
        # NOTE(review): the scale-by-255 and permute(1, 2, 0) suggest a float (C, H, W) tensor in [0, 1] - confirm
        scaled = image.mul(255).clamp(0, 255)
        as_bytes = scaled.byte()
        channels_last = as_bytes.permute(1, 2, 0)
        pixels = channels_last.cpu().numpy()
        Image.fromarray(pixels).save(filename)

    return handler


def _to_pyplot():
    """Build an output handler which displays each image it receives with ``matplotlib.pyplot``.

    Returns:
        callable: A function of (image, state); the state argument is ignored
    """
    import matplotlib.pyplot as plt

    def handler(image, _):
        # NOTE(review): the scale-by-255 and permute(1, 2, 0) suggest a float (C, H, W) tensor in [0, 1] - confirm
        scaled = image.mul(255).clamp(0, 255)
        channels_last = scaled.byte().permute(1, 2, 0)
        pixels = channels_last.cpu().numpy()
        plt.imshow(pixels)
        plt.show()

    return handler


def _to_tensorboard(name='Image', log_dir='./logs', comment='torchbearer'):
    """Build an output handler which writes each image it receives to tensorboard.

    Args:
        name (str): The tag the image is logged under
        log_dir (str): The tensorboard log path for output
        comment (str): Sub-directory joined onto ``log_dir``

    Returns:
        callable: A function of (image, state) which logs the image, clamped to [0, 1], at the step given by
        ``state[torchbearer.EPOCH]``
    """
    import torchbearer.callbacks.tensor_board as tb
    import os
    log_dir = os.path.join(log_dir, comment)

    def handler(image, state):
        # This function object is the identity under which the writer is requested and released
        writer = tb.get_writer(log_dir, _to_tensorboard)
        try:
            writer.add_image(name, image.clamp(0, 1), state[torchbearer.EPOCH])
        finally:
            # Release the writer even if add_image (or the state lookup) raises
            tb.close_writer(log_dir, _to_tensorboard)
    return handler


def _to_visdom(name='Image', log_dir='./logs', comment='torchbearer', visdom_params=None):
    """Build an output handler which writes each image it receives to visdom.

    Args:
        name (str): Prefix of the image name; the current epoch is appended to it
        log_dir (str): The visdom log path for output
        comment (str): Sub-directory joined onto ``log_dir``
        visdom_params: Passed through to ``get_writer`` unchanged (None by default)

    Returns:
        callable: A function of (image, state) which logs the image, clamped to [0, 1], at the step given by
        ``state[torchbearer.EPOCH]``
    """
    import torchbearer.callbacks.tensor_board as tb
    import os
    log_dir = os.path.join(log_dir, comment)

    def handler(image, state):
        # This function object is the identity under which the writer is requested and released
        writer = tb.get_writer(log_dir, _to_visdom, visdom=True, visdom_params=visdom_params)
        try:
            writer.add_image(name + str(state[torchbearer.EPOCH]), image.clamp(0, 1), state[torchbearer.EPOCH])
        finally:
            # Release the writer even if add_image (or the state lookup) raises
            tb.close_writer(log_dir, _to_visdom)

    return handler


class ImagingCallback(Callback):
    """Generic base for callbacks which produce an image that should be sent to a file, tensorboard, visdom etc.
    without bespoke output code. Implementers only write :meth:`on_batch`; outputs are attached with the chainable
    ``to_*`` methods or :meth:`with_handler`.

    Args:
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed
            version. This will be applied to the image before it is sent to output.
    """

    def __init__(self, transform=None):
        self._handlers = []

        def _identity(img):
            return img

        self.transform = _identity if transform is None else transform

    def on_batch(self, state):
        """Produce the image for the current state. Must be overridden.

        Args:
            state: The trial state

        Returns:
            The image to output, or None to output nothing for this call
        """
        raise NotImplementedError

    def process(self, state):
        """Call :meth:`on_batch` and, if it yields an image, transform it and pass it to every handler

        Args:
            state: The trial state
        """
        image = self.on_batch(state)
        if image is None:
            return
        image = self.transform(image)
        for send in self._handlers:
            send(image, state)

    def on_train(self):
        """Process this callback for training batches

        Returns:
            ImagingCallback: self
        """
        previous = self.on_step_training

        def step_then_process(state):
            previous(state)
            self.process(state)

        self.on_step_training = step_then_process
        return self

    def on_val(self):
        """Process this callback for validation batches

        Returns:
            ImagingCallback: self
        """
        previous = self.on_step_validation

        def step_then_process(state):
            previous(state)
            # The validation step hook is shared with test runs; only react to validation data
            if state[torchbearer.DATA] is torchbearer.VALIDATION_DATA:
                self.process(state)

        self.on_step_validation = step_then_process
        return self

    def on_test(self):
        """Process this callback for test batches

        Returns:
            ImagingCallback: self
        """
        previous = self.on_step_validation

        def step_then_process(state):
            previous(state)
            # The validation step hook is shared with validation runs; only react to test data
            if state[torchbearer.DATA] is torchbearer.TEST_DATA:
                self.process(state)

        self.on_step_validation = step_then_process
        return self

    def with_handler(self, handler):
        """Append the given output handler to the list of handlers

        Args:
            handler: A function of image and state which stores the given image in some way

        Returns:
            ImagingCallback: self
        """
        self._handlers.append(handler)
        return self

    def to_file(self, filename):
        """Send images from this callback to the given file

        Args:
            filename (str): the filename to store the image to

        Returns:
            ImagingCallback: self
        """
        file_handler = _to_file(filename)
        return self.with_handler(file_handler)

    def to_pyplot(self):
        """Show images from this callback with pyplot

        Returns:
            ImagingCallback: self
        """
        pyplot_handler = _to_pyplot()
        return self.with_handler(pyplot_handler)

    def to_state(self, key):
        """Put images from this callback in state with the given key

        Args:
            key (StateKey): The state key to use for the image

        Returns:
            ImagingCallback: self
        """
        def store(img, state):
            state[key] = img

        return self.with_handler(store)

    def to_tensorboard(self, name='Image', log_dir='./logs', comment='torchbearer'):
        """Direct images from this callback to tensorboard with the given parameters

        Args:
            name (str): The name of the image
            log_dir (str): The tensorboard log path for output
            comment (str): Descriptive comment to append to path

        Returns:
            ImagingCallback: self
        """
        tensorboard_handler = _to_tensorboard(name=name, log_dir=log_dir, comment=comment)
        return self.with_handler(tensorboard_handler)

    def to_visdom(self, name='Image', log_dir='./logs', comment='torchbearer', visdom_params=None):
        """Direct images from this callback to visdom with the given parameters

        Args:
            name (str): The name of the image
            log_dir (str): The visdom log path for output
            comment (str): Descriptive comment to append to path
            visdom_params (:class:`.VisdomParams`): Visdom parameter settings object, uses default if None

        Returns:
            ImagingCallback: self
        """
        visdom_handler = _to_visdom(name=name, log_dir=log_dir, comment=comment, visdom_params=visdom_params)
        return self.with_handler(visdom_handler)
class CachingImagingCallback(ImagingCallback):
    """An :class:`ImagingCallback` which collects images from the given state key, batch by batch, until
    ``num_images`` have been gathered, then passes the collection and the state to :meth:`on_cache`. After that it
    stays idle until the end of the epoch.

    Args:
        key (StateKey): The :class:`.StateKey` containing image data (tensor of size [b, c, w, h])
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed
            version. This will be applied to the image before it is sent to output.
        num_images: The number of images to cache
    """

    def __init__(self, key=torchbearer.INPUT, transform=None, num_images=16):
        super(CachingImagingCallback, self).__init__(transform=transform)
        self.key = key
        self.num_images = num_images

        self._data = None  # images gathered so far, or None when nothing has been cached yet
        self._done = False  # True once on_cache has been called; cleared in on_end_epoch

    def on_cache(self, cache, state):
        """This method should be implemented by the overriding class to return an image from the cache.

        Args:
            cache (tensor): The collected cache of size (num_images, C, W, H)
            state (dict): The trial state dict

        Returns:
            The processed image
        """
        raise NotImplementedError

    def on_batch(self, state):
        """Add images from the current batch to the cache and, once it is full, return the result of
        :meth:`on_cache`.

        Args:
            state: The trial state; ``state[self.key]`` is read and detached

        Returns:
            The image from :meth:`on_cache` on the call which fills the cache, otherwise None
        """
        if self._done:
            return None

        batch = state[self.key].detach()
        if self._data is None:
            # First batch since the last reset: take at most num_images from it
            self._data = batch[:min(self.num_images, batch.size(0))]
        else:
            # Top up with only as many images as are still missing
            missing = min(self.num_images - self._data.size(0), batch.size(0))
            self._data = torch.cat((self._data, batch[:missing]), dim=0)

        if self._data.size(0) < self.num_images:
            return None

        image = self.on_cache(self._data, state)
        self._done = True
        self._data = None
        return image

    def on_end_epoch(self, state):
        """Re-arm the callback so that caching starts again on the next call to :meth:`on_batch`

        Args:
            state: The trial state
        """
        super(CachingImagingCallback, self).on_end_epoch(state)
        self._done = False
class MakeGrid(CachingImagingCallback):
    """The :class:`MakeGrid` callback is a :class:`CachingImagingCallback` which calls make grid on the cache with
    the provided parameters. See
    `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
    for the meaning of the grid arguments.

    Args:
        key (StateKey): The :class:`.StateKey` containing image data (tensor of size [b, c, w, h])
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed
            version. This will be applied to the image before it is sent to output.
        num_images: The number of images to cache
        nrow: Passed to ``make_grid`` as ``nrow``
        padding: Passed to ``make_grid`` as ``padding``
        normalize: Passed to ``make_grid`` as ``normalize``
        norm_range: Passed to ``make_grid`` as ``value_range`` when the installed function accepts that name,
            otherwise as ``range``
        scale_each: Passed to ``make_grid`` as ``scale_each``
        pad_value: Passed to ``make_grid`` as ``pad_value``
    """

    def __init__(self, key=torchbearer.INPUT, transform=None, num_images=16, nrow=8, padding=2, normalize=False,
                 norm_range=None, scale_each=False, pad_value=0):
        # Let the parent store key and num_images rather than assigning them a second time here
        super(MakeGrid, self).__init__(key=key, transform=transform, num_images=num_images)

        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.norm_range = norm_range
        self.scale_each = scale_each
        self.pad_value = pad_value

    def on_cache(self, cache, state):
        """Arrange the cached images into a single grid image.

        Args:
            cache (tensor): The collected cache of images
            state (dict): The trial state dict (unused)

        Returns:
            The result of ``torchvision.utils.make_grid`` on the cache
        """
        import torchvision.utils as utils

        # Newer torchvision releases name the normalisation range argument ``value_range`` instead of ``range``.
        # Ask the installed function which one it takes instead of guessing from a version string.
        try:
            from inspect import signature
            has_value_range = 'value_range' in signature(utils.make_grid).parameters
        except (ImportError, TypeError, ValueError):
            # No inspect.signature (Python 2) or an uninspectable callable: keep the original keyword
            has_value_range = False
        range_keyword = 'value_range' if has_value_range else 'range'

        return utils.make_grid(
            cache,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
            **{range_keyword: self.norm_range}
        )
class FromState(ImagingCallback):
    """The :class:`FromState` callback is an :class:`ImagingCallback` which retrieves an image from state when
    called. The number of times the function is called can be controlled with a provided decorator
    (once_per_epoch, only_if etc.)

    Args:
        key (StateKey): The :class:`.StateKey` containing the image (tensor of size [c, w, h])
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed
            version. This will be applied to the image before it is sent to output.
        decorator: A function which will be used to wrap the callback function. once_per_epoch by default
    """

    def __init__(self, key, transform=None, decorator=once_per_epoch):
        super(FromState, self).__init__(transform=transform)
        self.key = key
        # Wrap the bound method on this instance so the decorator decides when on_batch actually runs
        self.on_batch = decorator(self.on_batch)

    def on_batch(self, state):
        """Look up the image stored under ``self.key``.

        Args:
            state: The trial state

        Returns:
            The value in state for the key, or None if the lookup raises KeyError
        """
        try:
            image = state[self.key]
        except KeyError:
            image = None
        return image