# Source code for torchbearer.callbacks.decorators

import sys

if sys.version_info[0] >= 3:
    from inspect import signature

    def count_args(fcn):
        """Return the number of parameters in ``fcn``'s signature.

        On Python 3 this counts every entry of :func:`inspect.signature`'s ``parameters``, including
        ``*args``/``**kwargs`` and keyword-only parameters. For a bound method, ``self`` is not included.

        Args:
            fcn (function): The callable to inspect

        Returns:
            int: The number of parameters
        """
        params = signature(fcn).parameters
        return len(params)
else:
    import inspect

    def count_args(fcn):
        """Return the number of named positional arguments of ``fcn``.

        On Python 2 this counts ``inspect.getargspec(fcn).args``. For a method, that includes ``self``.

        Args:
            fcn (function): The callable to inspect

        Returns:
            int: The number of named positional arguments
        """
        spec = inspect.getargspec(fcn)
        return len(spec.args)
import types

import torchbearer
from torchbearer.callbacks import Callback
class LambdaCallback(Callback):
    """A :class:`.Callback` that forwards to a single wrapped function.

    The decorators in this module create instances of this class. They bind the chosen callback hook(s) to
    :meth:`on_lambda`, which calls the wrapped function with ``state``.

    Args:
        func (function): The function(state) called by :meth:`on_lambda`
    """

    def __init__(self, func):
        # Initialise the Callback base class so that any set-up it performs is not skipped.
        # The two-argument ``super`` form keeps Python 2 compatibility, which this module supports.
        super(LambdaCallback, self).__init__()
        self.func = func

    def on_lambda(self, state):
        """Call the wrapped function with ``state`` and return its result.

        Args:
            state: The state passed to the bound callback hook

        Returns:
            Whatever the wrapped function returns
        """
        return self.func(state)
def bind_to(target):
    """Build a decorator that binds the hook named like ``target`` to a decorated function.

    Args:
        target (function): A :class:`.Callback` method (e.g. ``Callback.on_start``). Only its ``__name__``
            is used, to choose which attribute to bind on the callback.

    Returns:
        function: A decorator taking a function(state) or an existing :class:`LambdaCallback`. It returns a
        :class:`LambdaCallback` whose hook named ``target.__name__`` calls :meth:`LambdaCallback.on_lambda`.
    """
    def decorator(func):
        # Reuse an existing LambdaCallback so several hook decorators can be stacked onto one callback
        callback = func if isinstance(func, LambdaCallback) else LambdaCallback(func)

        def _forward(self, state):
            return self.on_lambda(state)

        bound_hook = types.MethodType(_forward, callback)
        setattr(callback, target.__name__, bound_hook)
        return callback

    return decorator
def on_start(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_start` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_start` calling func
    """
    binder = bind_to(Callback.on_start)
    return binder(func)
def on_start_epoch(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_start_epoch` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_start_epoch` calling func
    """
    binder = bind_to(Callback.on_start_epoch)
    return binder(func)
def on_start_training(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_start_training` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_start_training` calling func
    """
    binder = bind_to(Callback.on_start_training)
    return binder(func)
def on_sample(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_sample` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_sample` calling func
    """
    binder = bind_to(Callback.on_sample)
    return binder(func)
def on_forward(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_forward` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_forward` calling func
    """
    binder = bind_to(Callback.on_forward)
    return binder(func)
def on_criterion(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_criterion` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_criterion` calling func
    """
    binder = bind_to(Callback.on_criterion)
    return binder(func)
def on_backward(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_backward` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_backward` calling func
    """
    binder = bind_to(Callback.on_backward)
    return binder(func)
def on_step_training(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_step_training` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_step_training` calling func
    """
    binder = bind_to(Callback.on_step_training)
    return binder(func)
def on_end_training(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_end_training` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_end_training` calling func
    """
    binder = bind_to(Callback.on_end_training)
    return binder(func)
def on_end_epoch(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_end_epoch` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_end_epoch` calling func
    """
    binder = bind_to(Callback.on_end_epoch)
    return binder(func)
def on_end(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_end` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_end` calling func
    """
    binder = bind_to(Callback.on_end)
    return binder(func)
def on_start_validation(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_start_validation` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_start_validation` calling func
    """
    binder = bind_to(Callback.on_start_validation)
    return binder(func)
def on_sample_validation(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_sample_validation` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_sample_validation` calling func
    """
    binder = bind_to(Callback.on_sample_validation)
    return binder(func)
def on_forward_validation(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_forward_validation` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_forward_validation` calling func
    """
    binder = bind_to(Callback.on_forward_validation)
    return binder(func)
def on_criterion_validation(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_criterion_validation` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_criterion_validation` calling func
    """
    binder = bind_to(Callback.on_criterion_validation)
    return binder(func)
def on_end_validation(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_end_validation` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_end_validation` calling func
    """
    binder = bind_to(Callback.on_end_validation)
    return binder(func)
def on_step_validation(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_step_validation` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_step_validation` calling func
    """
    binder = bind_to(Callback.on_step_validation)
    return binder(func)
def on_checkpoint(func):
    """Turn ``func`` into a :class:`.Callback` whose :meth:`~.Callback.on_checkpoint` calls it.

    Can be stacked with the other hook decorators to bind several hooks to one function.

    Args:
        func (function): The function(state) to *decorate*

    Returns:
        Callback: Initialised callback with :meth:`~.Callback.on_checkpoint` calling func
    """
    binder = bind_to(Callback.on_checkpoint)
    return binder(func)
def add_to_loss(func):
    """Build a :class:`.Callback` that adds the value returned by ``func`` to the loss.

    The callback is bound to both :meth:`~.Callback.on_criterion` and
    :meth:`~.Callback.on_criterion_validation`. On each call it sets ``state[torchbearer.LOSS]`` to
    ``state[torchbearer.LOSS] + func(state)``.

    Args:
        func (function): The function(state) to *decorate*; its return value is added to the loss

    Returns:
        Callback: Initialised callback which adds the returned value from func to the loss
    """
    def _accumulate_loss(state):
        # Read the current loss before calling func (same order as ``a + func(state)``).
        # Plain ``+`` is used instead of ``+=``, so the state entry is rebound rather than updated in place.
        current_loss = state[torchbearer.LOSS]
        state[torchbearer.LOSS] = current_loss + func(state)

    # Equivalent to stacking @on_criterion over @on_criterion_validation
    return on_criterion(on_criterion_validation(_accumulate_loss))
def once(fcn):
    """Decorator that lets a callback fire only the first time it is called.

    For a class method, the "already fired" flag is stored on the instance, so each instance of the class
    fires once. For a function of state, the flag is stored on the wrapper that :func:`only_if` creates.
    So only the first call fires, even if more than one function is present in the callback list.

    Args:
        fcn (function): the `torchbearer callback` function to decorate.

    Returns:
        the decorated callback
    """
    _unset = object()

    def _first_call(owner, _state):
        done = getattr(owner, '__done__', _unset)
        if done is _unset:
            # First call: record it and allow the callback to run
            owner.__done__ = True
            return True
        return not done

    return only_if(_first_call)(fcn)
def once_per_epoch(fcn):
    """Decorator that lets a callback fire only on its first call within each epoch.

    The epoch of the last firing is recorded as ``__last_epoch__``. For a class method it is stored on the
    instance, so each instance fires once per epoch. For a function of state it is stored on the wrapper
    that :func:`only_if` creates. So only the first call in each epoch fires, even if more than one function
    is present in the callback list.

    .. note::

        The decorated callback may exhibit unusual behaviour if it is reused

    Args:
        fcn (function): the `torchbearer callback` function to decorate.

    Returns:
        the decorated callback
    """
    def _first_in_epoch(owner, state):
        current_epoch = state[torchbearer.EPOCH]
        try:
            previous_epoch = owner.__last_epoch__
        except AttributeError:
            # Never fired before: remember this epoch and fire
            owner.__last_epoch__ = current_epoch
            return True
        if current_epoch != previous_epoch:
            owner.__last_epoch__ = current_epoch
            return True
        return False

    return only_if(_first_in_epoch)(fcn)
def _condition_arity(condition_expr):
    """Return how many positional arguments ``condition_expr`` should be called with.

    Returns 2 if it can be called as ``condition_expr(self, state)`` and 1 if it can only be called as
    ``condition_expr(state)``. Returns None if this cannot be determined by introspection; callers then fall
    back to trial calls.
    """
    import inspect
    try:
        sig = inspect.signature(condition_expr)
    except (AttributeError, TypeError, ValueError):
        # AttributeError: inspect.signature is unavailable (Python 2).
        # TypeError/ValueError: the callable cannot be introspected (e.g. some builtins).
        return None
    for n_args in (2, 1):
        try:
            sig.bind(*([None] * n_args))
        except TypeError:
            continue
        return n_args
    return None


def _call_condition(condition_expr, arity, owner, state):
    """Evaluate ``condition_expr`` for ``owner``/``state`` using an arity from :func:`_condition_arity`."""
    if arity == 2:
        return condition_expr(owner, state)
    if arity == 1:
        return condition_expr(state)
    # Arity unknown: keep the historical behaviour of trying (owner, state) first, then (state)
    try:
        return condition_expr(owner, state)
    except TypeError:
        return condition_expr(state)


def only_if(condition_expr):
    """
    Decorator to fire a callback only if the given conditional expression function returns True. The
    conditional expression can be a function of state, or of self and state. If the decorated function is not a
    class method (i.e. it does not take self), a wrapper around the decorated function is passed in place of
    self instead. This enables the storing of temporary variables.

    Args:
        condition_expr (function(self, state) or function(state)): a function/lambda which takes state and
            optionally self that must evaluate to true for the decorated `torchbearer callback` to be called.
            The `state` object passed to the callback will be passed as an argument to the condition function.

    Returns:
        the decorator
    """
    # Resolve the condition's calling convention once. A TypeError raised *inside* a two-argument condition
    # then propagates instead of being mistaken for an arity mismatch, which would re-run the condition.
    arity = _condition_arity(condition_expr)

    def condition_decorator(fcn):
        if isinstance(fcn, LambdaCallback):
            # Decorator-built callbacks: wrap the underlying function and keep the same callback object
            fcn.func = condition_decorator(fcn.func)
            return fcn
        count = count_args(fcn)
        if count == 2 and not hasattr(fcn, '__self__'):
            # Two parameters and not already bound: assume a method taking (self, state)
            def decfcn(o, state):
                if _call_condition(condition_expr, arity, o, state):
                    return fcn(o, state)
        else:
            # Assume a function of state. Attributes cannot be set on bound methods, so a fresh wrapper
            # function stands in for ``self`` and gives the condition somewhere to store its variables.
            def id_fcn(state):
                return fcn(state)

            def decfcn(state):
                if _call_condition(condition_expr, arity, id_fcn, state):
                    return id_fcn(state)
        return decfcn

    return condition_decorator