Source code for torchbearer.callbacks.imaging.imaging

import torchbearer
from torchbearer import Callback

import torch


def _to_file(filename):
    """Create an output handler which saves each image it receives to disk with PIL.

    Args:
        filename (str): Format string for the output path, e.g. ``'image_{index}.png'``. It is formatted with
            ``index`` (the image index) and with the string form of every key in the model state and in its
            metrics dict.

    Returns:
        callable: A handler with signature ``(image, index, model_state)``
    """
    from PIL import Image

    def handler(image, index, model_state):
        state = {}
        state.update(model_state)
        state.update(model_state[torchbearer.METRICS])

        format_args = {str(key): state[key] for key in state.keys()}
        # Put ``index`` into the mapping rather than passing it as a separate keyword so that a state key which
        # also stringifies to 'index' cannot trigger a "multiple values for keyword argument" TypeError
        format_args['index'] = str(index)

        # Channel-first float image in [0, 1] -> channel-last uint8 array for PIL
        ndarr = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        if ndarr.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array; drop the channel axis so it is saved as greyscale
            ndarr = ndarr[:, :, 0]
        im = Image.fromarray(ndarr)
        im.save(filename.format(**format_args))

    return handler


def _to_pyplot(title=None, show=True):
    """Create an output handler which draws each image it receives with ``matplotlib.pyplot``.

    Args:
        title (str or None): If not None, formatted with ``index`` and passed to ``plt.title``
        show (bool): If True, ``plt.show`` is called after each image is drawn

    Returns:
        callable: A handler with signature ``(image, index, state)``; the state argument is ignored
    """
    import matplotlib.pyplot as plt

    def handler(image, index, _):
        # Channel-first float image in [0, 1] -> channel-last uint8 array for matplotlib
        ndarr = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        if ndarr.shape[2] == 1:
            # Single channel: draw the 2-D array with a fixed grey scale instead of an autoscaled colour map
            plt.imshow(ndarr[:, :, 0], cmap='gray', vmin=0, vmax=255)
        else:
            plt.imshow(ndarr)
        if title is not None:
            plt.title(title.format(index=str(index)))
        plt.axis('off')

        if show:
            plt.show()

    return handler


def _to_tensorboard(name='Image', log_dir='./logs', comment='torchbearer'):
    """Create an output handler which logs each image it receives to tensorboard.

    Args:
        name (str): Tag for the image, formatted with ``index``
        log_dir (str): Base tensorboard log directory
        comment (str): Sub-directory of ``log_dir`` to write to

    Returns:
        callable: A handler with signature ``(image, index, state)``
    """
    import torchbearer.callbacks.tensor_board as tb
    import os
    log_dir = os.path.join(log_dir, comment)

    def handler(image, index, state):
        # This function itself is passed as the logger argument to get_writer / close_writer
        writer = tb.get_writer(log_dir, _to_tensorboard)
        try:
            writer.add_image(name.format(index=str(index)), image.clamp(0, 1), state[torchbearer.EPOCH])
        finally:
            # Always release the writer, even if logging the image fails
            tb.close_writer(log_dir, _to_tensorboard)

    return handler


def _to_visdom(name='Image', log_dir='./logs', comment='torchbearer', visdom_params=None):
    """Create an output handler which logs each image it receives to visdom through the tensorboard writer helpers.

    Args:
        name (str): Tag for the image, formatted with ``index``; the epoch number is appended after an underscore
        log_dir (str): Base log directory
        comment (str): Sub-directory of ``log_dir`` to write to
        visdom_params: Visdom parameter settings object passed to ``get_writer``; None uses its default

    Returns:
        callable: A handler with signature ``(image, index, state)``
    """
    import torchbearer.callbacks.tensor_board as tb
    import os
    log_dir = os.path.join(log_dir, comment)

    def handler(image, index, state):
        # This function itself is passed as the logger argument to get_writer / close_writer
        writer = tb.get_writer(log_dir, _to_visdom, visdom=True, visdom_params=visdom_params)
        try:
            epoch = state[torchbearer.EPOCH]
            writer.add_image(name.format(index=str(index)) + '_' + str(epoch), image.clamp(0, 1), epoch)
        finally:
            # Always release the writer, even if logging the image fails
            tb.close_writer(log_dir, _to_visdom)

    return handler


def _cache_images(num_images):
    """Create a decorator which accumulates the images returned by a function until ``num_images`` are collected.

    On every step whose batch index is 0 the cache is cleared. The decorated function is then called on each step
    and its output is appended to the cache until it holds ``num_images`` images. At that point the full cache is
    returned (once), and the function is not called again until the next step with batch index 0. All other steps
    return None.

    Args:
        num_images (int): The number of images to collect before the cache is returned

    Returns:
        callable: A decorator for a function of ``state`` which returns a batch of images (first dim indexes images)
            or None
    """
    cache = {'images': None, 'done': False}

    def decorator(fun):
        def step(state):
            if state[torchbearer.BATCH] == 0:
                # Start of a new pass over the data: also drop any partial cache left over from the previous pass so
                # images from different epochs are never mixed
                cache['done'] = False
                cache['images'] = None

            if cache['done']:
                return None

            data = fun(state)
            if data is None:
                # Nothing to add on this step (e.g. the wrapped function found no image)
                return None

            # Slicing clamps to the batch size, so data[:n] takes at most what is still needed
            if cache['images'] is None:
                cache['images'] = data[:num_images]
            else:
                needed = num_images - cache['images'].size(0)
                cache['images'] = torch.cat((cache['images'], data[:needed]), dim=0)

            if cache['images'].size(0) >= num_images:
                full = cache['images']
                cache['done'] = True
                cache['images'] = None
                return full
            return None
        return step
    return decorator


class ImagingCallback(Callback):
    """The :class:`ImagingCallback` provides a generic interface for callbacks which yield images that should be sent
    to a file, tensorboard, visdom etc. without needing bespoke code. This allows the user to easily define custom
    visualisations by only writing the code to produce the image.

    Args:
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed version.
            This will be applied to the image before it is sent to output.
    """
    def __init__(self, transform=None):
        self._handlers = []
        self.transform = (lambda img: img) if transform is None else transform

    def on_batch(self, state):
        """Produce the image(s) to output for the current step. Must be implemented by subclasses.

        Args:
            state (dict): The current trial state

        Returns:
            A single 3-D image tensor, a batch of images indexed by the first dimension, or None to output nothing
        """
        raise NotImplementedError

    def process(self, state):
        """Call :meth:`on_batch`, apply the transform and pass the resulting image(s) to every registered handler.

        Args:
            state (dict): The current trial state
        """
        img = self.on_batch(state)
        if img is None:
            return
        img = self.transform(img)
        if img.dim() == 3:
            # A single image is treated as a batch of one so it can be indexed like any other batch
            img = img.unsqueeze(0)

        for handler, index in self._handlers:
            indices = range(img.size(0)) if index is None else index
            # Only the iterability probe is guarded, so a TypeError raised by a handler is no longer swallowed
            # (and the handler is no longer re-invoked after partial calls)
            try:
                iterator = iter(indices)
            except TypeError:
                # A single, non-iterable index such as an int
                handler(img[indices], indices, state)
            else:
                for i in iterator:
                    handler(img[i], i, state)

    def on_train(self):
        """Process this callback for training batches

        Returns:
            ImagingCallback: self
        """
        _old_step_training = self.on_step_training

        def wrapper(state):
            _old_step_training(state)
            self.process(state)
        self.on_step_training = wrapper
        return self

    def on_val(self):
        """Process this callback for validation batches

        Returns:
            ImagingCallback: self
        """
        _old_step_validation = self.on_step_validation

        def wrapper(state):
            _old_step_validation(state)
            if state[torchbearer.DATA] is torchbearer.VALIDATION_DATA:
                self.process(state)
        self.on_step_validation = wrapper
        return self

    def on_test(self):
        """Process this callback for test batches

        Returns:
            ImagingCallback: self
        """
        _old_step_validation = self.on_step_validation

        def wrapper(state):
            _old_step_validation(state)
            if state[torchbearer.DATA] is torchbearer.TEST_DATA:
                self.process(state)
        self.on_step_validation = wrapper
        return self

    def with_handler(self, handler, index=None):
        """Append the given output handler to the list of handlers

        Args:
            handler: A function of image, index and state which stores the given image in some way
            index (int or list or None): If not None, only apply the handler on this index / list of indices

        Returns:
            ImagingCallback: self
        """
        self._handlers.append((handler, index))
        return self

    def to_file(self, filename, index=None):
        """Send images from this callback to the given file

        Args:
            filename (str): The filename to store the image to
            index (int or list or None): If not None, only apply the handler on this index / list of indices

        Returns:
            ImagingCallback: self
        """
        return self.with_handler(_to_file(filename), index=index)

    def to_pyplot(self, title=None, show=True, index=None):
        """Show images from this callback with pyplot

        Args:
            title (str or None): If not None, plt.title will be called with the given string
            show (bool): If True (default), show will be called after each image is plotted
            index (int or list or None): If not None, only apply the handler on this index / list of indices

        Returns:
            ImagingCallback: self
        """
        return self.with_handler(_to_pyplot(title=title, show=show), index=index)

    def to_state(self, keys, index=None):
        """Put images from this callback in state with the given key

        Args:
            keys (StateKey or list[StateKey]): The state key or keys to use for the images; image ``i`` is stored
                under ``keys[i]``
            index (int or list or None): If not None, only apply the handler on this index / list of indices

        Returns:
            ImagingCallback: self
        """
        if str(keys) == keys:
            # A single key which compares equal to its string form is wrapped rather than iterated
            keys = [keys]
        try:
            iter(keys)
        except TypeError:
            keys = [keys]

        def handler(img, i, state):
            state[keys[i]] = img
        return self.with_handler(handler, index=index)

    def to_tensorboard(self, name='Image', log_dir='./logs', comment='torchbearer', index=None):
        """Direct images from this callback to tensorboard with the given parameters

        Args:
            name (str): The name of the image
            log_dir (str): The tensorboard log path for output
            comment (str): Descriptive comment to append to path
            index (int or list or None): if not None, only apply the handler on this index / list of indices

        Returns:
            ImagingCallback: self
        """
        return self.with_handler(_to_tensorboard(name=name, log_dir=log_dir, comment=comment), index=index)

    def to_visdom(self, name='Image', log_dir='./logs', comment='torchbearer', visdom_params=None, index=None):
        """Direct images from this callback to visdom with the given parameters

        Args:
            name (str): The name of the image
            log_dir (str): The visdom log path for output
            comment (str): Descriptive comment to append to path
            visdom_params (:class:`.VisdomParams`): Visdom parameter settings object, uses default if None
            index (int or list or None): if not None, only apply the handler on this index / list of indices

        Returns:
            ImagingCallback: self
        """
        return self.with_handler(_to_visdom(name=name, log_dir=log_dir, comment=comment,
                                            visdom_params=visdom_params), index=index)

    def cache(self, num_images):
        """Cache images **before** they are passed to handlers. Once per epoch, a single cache will be returned,
        containing the first `num_images` to be returned.

        Args:
            num_images (int): The number of images to cache

        Returns:
            ImagingCallback: self
        """
        self.on_batch = _cache_images(num_images)(self.on_batch)
        return self

    def make_grid(self, nrow=8, padding=2, normalize=False, norm_range=None, scale_each=False, pad_value=0):
        """Use `torchvision.utils.make_grid` to make a grid of the images being returned by this callback.
        Recommended for use alongside `cache`.

        Args:
            nrow: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
            padding: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
            normalize: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
            norm_range: The value range used for normalisation (``value_range``, formerly ``range``). See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
            scale_each: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
            pad_value: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_

        Returns:
            ImagingCallback: self
        """
        import inspect
        import torchvision.utils as utils

        grid_kwargs = dict(nrow=nrow, padding=padding, normalize=normalize, scale_each=scale_each,
                           pad_value=pad_value)
        if norm_range is not None:
            # torchvision 0.10 renamed ``range`` to ``value_range`` and deprecated (later dropping) the old name, so
            # pass the range under whichever name the installed version accepts. When no range is given the
            # argument is omitted entirely, which works with every version.
            try:
                accepts_value_range = 'value_range' in inspect.signature(utils.make_grid).parameters
            except (AttributeError, TypeError, ValueError):
                accepts_value_range = False
            grid_kwargs['value_range' if accepts_value_range else 'range'] = norm_range

        def decorator(func):
            def wrapper(state):
                cache = func(state)
                if cache is not None:
                    return utils.make_grid(cache, **grid_kwargs)
            return wrapper
        self.on_batch = decorator(self.on_batch)
        return self


class FromState(ImagingCallback):
    """The :class:`FromState` callback is an :class:`ImagingCallback` which retrieves an image from state when called.
    The number of times the function is called can be controlled with a provided decorator (once_per_epoch, only_if
    etc.)

    Args:
        key (StateKey): The :class:`.StateKey` containing the image (tensor of size [c, w, h])
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed version.
            This will be applied to the image before it is sent to output.
        decorator: A function which will be used to wrap the callback function. If None (the default), the callback
            function is used undecorated.
    """
    def __init__(self, key, transform=None, decorator=None):
        super(FromState, self).__init__(transform=transform)
        self.key = key
        if decorator is not None:
            self.on_batch = decorator(self.on_batch)

    def on_batch(self, state):
        """Return the value stored in state under ``self.key``, or None if the key is not present.

        Args:
            state (dict): The current trial state

        Returns:
            The stored image, or None if ``self.key`` is missing
        """
        try:
            return state[self.key]
        except KeyError:
            return None


class CachingImagingCallback(FromState):
    """The :class:`CachingImagingCallback` is an :class:`ImagingCallback` which caches batches of images from the
    given state key up to the required amount before passing this along with state to the implementing class, once
    per epoch.

    Args:
        key (StateKey): The :class:`.StateKey` containing image data (tensor of size [b, c, w, h])
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed version.
            This will be applied to the image before it is sent to output.
        num_images: The number of images to cache
    """
    def __init__(self, key=torchbearer.INPUT, transform=None, num_images=16):
        super(CachingImagingCallback, self).__init__(key=key, transform=transform,
                                                     decorator=_cache_images(num_images))

        # Layer on_cache over the caching on_batch installed by FromState: only a completed cache reaches on_cache
        caching_on_batch = self.on_batch

        def on_full_cache(state):
            collected = caching_on_batch(state)
            return None if collected is None else self.on_cache(collected, state)

        self.on_batch = on_full_cache

    def on_cache(self, cache, state):
        """This method should be implemented by the overriding class to return an image from the cache.

        Args:
            cache (tensor): The collected cache of size (num_images, C, W, H)
            state (dict): The trial state dict

        Returns:
            The processed image
        """
        raise NotImplementedError


class MakeGrid(CachingImagingCallback):
    """The :class:`MakeGrid` callback is a :class:`CachingImagingCallback` which calls make grid on the cache with the
    provided parameters.

    Args:
        key (StateKey): The :class:`.StateKey` containing image data (tensor of size [b, c, w, h])
        transform (callable, optional): A function/transform that takes in a Tensor and returns a transformed version.
            This will be applied to the image before it is sent to output.
        num_images: The number of images to cache
        nrow: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        padding: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        normalize: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        norm_range: The value range used for normalisation (``value_range``, formerly ``range``). See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        scale_each: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        pad_value: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
    """
    def __init__(self, key=torchbearer.INPUT, transform=None, num_images=16, nrow=8, padding=2, normalize=False,
                 norm_range=None, scale_each=False, pad_value=0):
        super(MakeGrid, self).__init__(transform=transform, num_images=num_images, key=key)
        self.key = key
        self.num_images = num_images
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.norm_range = norm_range
        self.scale_each = scale_each
        self.pad_value = pad_value

    def on_cache(self, cache, state):
        """Arrange the cached images into a single grid image with `torchvision.utils.make_grid`.

        Args:
            cache (tensor): The collected cache of images
            state (dict): The trial state dict (unused)

        Returns:
            The grid image produced by `torchvision.utils.make_grid`
        """
        import inspect
        import torchvision.utils as utils

        grid_kwargs = dict(nrow=self.nrow, padding=self.padding, normalize=self.normalize,
                           scale_each=self.scale_each, pad_value=self.pad_value)
        if self.norm_range is not None:
            # torchvision 0.10 renamed ``range`` to ``value_range`` and deprecated (later dropping) the old name, so
            # pass the range under whichever name the installed version accepts. When no range is given the
            # argument is omitted entirely, which works with every version.
            try:
                accepts_value_range = 'value_range' in inspect.signature(utils.make_grid).parameters
            except (AttributeError, TypeError, ValueError):
                accepts_value_range = False
            grid_kwargs['value_range' if accepts_value_range else 'range'] = self.norm_range

        return utils.make_grid(cache, **grid_kwargs)