import time
from torchbearer.callbacks import Callback
import torchbearer
class TimerCallback(Callback):
    """Timer callback that aggregates timings for each stage of model execution.

    Each hook records, under its own label, the number of seconds elapsed since
    the previously recorded hook and then restarts the clock. Only the most
    recent measurement per label is kept (later calls overwrite earlier ones).
    The timings dict is published to ``state[torchbearer.TIMINGS]`` on every
    update and is also returned by :meth:`get_timings`.
    """

    def __init__(self):
        """Timer callback that aggregates timings for each stage of model execution."""
        super().__init__()
        # perf_counter is monotonic and high-resolution, so measured intervals are
        # not distorted by system clock adjustments the way time.time() can be.
        self.t0 = time.perf_counter()
        # Label -> seconds since the previous recorded hook (last value only).
        self.time_dict = {}

    def update_time(self, text, state):
        """Record the time since the last update under ``text`` and restart the clock.

        Args:
            text (str): Label to store the elapsed time under.
            state: Model state mapping; ``state[torchbearer.TIMINGS]`` is set to
                this callback's timings dict.
        """
        # Read the clock once so the end of this interval and the start of the
        # next are the same instant; no time falls between two measurements.
        now = time.perf_counter()
        self.time_dict[text] = now - self.t0
        state[torchbearer.TIMINGS] = self.time_dict
        self.t0 = now

    def on_start(self, state):
        """Restart the clock and record 'OnStart' (effectively zero)."""
        super().on_start(state)
        # Discard time spent between construction and the start of fitting.
        self.t0 = time.perf_counter()
        self.update_time('OnStart', state)

    def on_start_training(self, state):
        """Record time since the previous hook under 'OnStartTraining'."""
        super().on_start_training(state)
        self.update_time('OnStartTraining', state)

    def on_start_epoch(self, state):
        """Record time since the previous hook under 'OnStartEpoch'."""
        super().on_start_epoch(state)
        self.update_time('OnStartEpoch', state)

    def on_sample(self, state):
        """Record time since the previous hook under 'OnSample'."""
        super().on_sample(state)
        self.update_time('OnSample', state)

    def on_forward(self, state):
        """Record time since the previous hook under 'OnForward'."""
        super().on_forward(state)
        self.update_time('OnForward', state)

    def on_criterion(self, state):
        """Record time since the previous hook under 'OnCriterion'."""
        super().on_criterion(state)
        self.update_time('OnCriterion', state)

    def on_backward(self, state):
        """Record time since the previous hook under 'OnBackward'."""
        super().on_backward(state)
        self.update_time('OnBackward', state)

    def on_step_training(self, state):
        """Record time since the previous hook under 'OnStep'."""
        super().on_step_training(state)
        self.update_time('OnStep', state)

    def on_start_validation(self, state):
        """Record time since the previous hook under 'OnStartValidation'."""
        super().on_start_validation(state)
        self.update_time('OnStartValidation', state)

    def on_sample_validation(self, state):
        """Record time since the previous hook under 'OnSampleValidation'."""
        super().on_sample_validation(state)
        self.update_time('OnSampleValidation', state)

    def on_forward_validation(self, state):
        """Record time since the previous hook under 'OnForwardValidation'."""
        super().on_forward_validation(state)
        self.update_time('OnForwardValidation', state)

    def on_criterion_validation(self, state):
        """Record time since the previous hook under 'OnCriterionValidation'."""
        super().on_criterion_validation(state)
        self.update_time('OnCriterionValidation', state)

    def on_step_validation(self, state):
        """Record time since the previous hook under 'OnStepValidation'."""
        super().on_step_validation(state)
        self.update_time('OnStepValidation', state)

    def get_timings(self):
        """Return the timings dict (label -> seconds).

        This is the same object stored in ``state[torchbearer.TIMINGS]``.
        """
        return self.time_dict