from torchbearer import Callback
[docs]
class CallbackList(Callback):
"""The :class:`CallbackList` class is a wrapper for a list of callbacks which acts as a single :class:`.Callback` and
internally calls each :class:`.Callback` in the given list in turn.
Args:
callback_list (list): The list of callbacks to be wrapped. If the list contains a :class:`CallbackList`, this
will be unwrapped.
"""
CALLBACK_STATES = 'callback_states'
CALLBACK_TYPES = 'callback_types'
def __init__(self, callback_list):
super(CallbackList, self).__init__()
self.callback_list = []
self.append(callback_list)
[docs]
def state_dict(self):
"""Get a dict containing all of the callback states.
Returns:
dict: A dict containing parameters and persistent buffers.
"""
state_dict = {
CallbackList.CALLBACK_STATES: [],
CallbackList.CALLBACK_TYPES: []
}
def to_state(callback):
state_dict[CallbackList.CALLBACK_STATES].append(callback.state_dict())
state_dict[CallbackList.CALLBACK_TYPES].append(callback.__class__)
self._for_list(to_state)
return state_dict
[docs]
def load_state_dict(self, state_dict):
"""Resume this callback list from the given state. Callbacks must be given in the same order for this to work.
Args:
state_dict (dict): The state dict to reload
Returns:
CallbackList: self
"""
t_iter = iter(state_dict[CallbackList.CALLBACK_TYPES])
s_iter = iter(state_dict[CallbackList.CALLBACK_STATES])
def from_state(callback):
if callback.__class__ == next(t_iter):
callback.load_state_dict(next(s_iter))
else:
import warnings
warnings.warn('Callback classes did not match, expected: ' + str([c.__name__ for c in state_dict[CallbackList.CALLBACK_TYPES]]))
self._for_list(from_state)
return self
def _for_list(self, function):
for callback in self.callback_list:
function(callback)
def __str__(self):
return str([str(c) for c in self.callback_list])
def __iter__(self):
return self.callback_list.__iter__()
def __copy__(self):
return CallbackList(self.callback_list)
[docs]
def copy(self):
return self.__copy__()
[docs]
def append(self, callback_list):
for callback in callback_list:
if isinstance(callback, CallbackList):
self.callback_list = self.callback_list + callback.callback_list
else:
self.callback_list.append(callback)
[docs]
def on_init(self, state):
"""Call on_init on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_init(state))
[docs]
def on_start(self, state):
"""Call on_start on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_start(state))
[docs]
def on_start_epoch(self, state):
"""Call on_start_epoch on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_start_epoch(state))
[docs]
def on_start_training(self, state):
"""Call on_start_training on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_start_training(state))
[docs]
def on_sample(self, state):
"""Call on_sample on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_sample(state))
[docs]
def on_forward(self, state):
"""Call on_forward on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_forward(state))
[docs]
def on_criterion(self, state):
"""Call on_criterion on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_criterion(state))
[docs]
def on_backward(self, state):
"""Call on_backward on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_backward(state))
[docs]
def on_step_training(self, state):
"""Call on_step_training on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_step_training(state))
[docs]
def on_end_training(self, state):
"""Call on_end_training on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_end_training(state))
[docs]
def on_start_validation(self, state):
"""Call on_start_validation on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_start_validation(state))
[docs]
def on_sample_validation(self, state):
"""Call on_sample_validation on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_sample_validation(state))
[docs]
def on_forward_validation(self, state):
"""Call on_forward_validation on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_forward_validation(state))
[docs]
def on_criterion_validation(self, state):
"""Call on_criterion_validation on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_criterion_validation(state))
[docs]
def on_step_validation(self, state):
"""Call on_step_validation on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_step_validation(state))
[docs]
def on_end_validation(self, state):
"""Call on_end_validation on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_end_validation(state))
[docs]
def on_end_epoch(self, state):
"""Call on_end_epoch on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_end_epoch(state))
[docs]
def on_checkpoint(self, state):
"""Call on_checkpoint on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_checkpoint(state))
[docs]
def on_end(self, state):
"""Call on_end on each callback in turn with the given state.
Args:
state (dict[str,any]): The current state dict of the :class:`.Trial`.
"""
self._for_list(lambda callback: callback.on_end(state))