Source code for torchbearer.callbacks.timer

import time
from torchbearer.callbacks import Callback
import torchbearer


class TimerCallback(Callback):
    """Timer callback that aggregates timings for each stage of model execution.

    Each hook records the time elapsed since the previously recorded stage
    (not a cumulative total) and then restarts the clock. Values are in
    seconds, as returned by ``time.perf_counter``. A key is overwritten
    every time its hook fires, so only the most recent duration per stage is
    kept.
    """

    def __init__(self):
        """Timer callback that aggregates timings for each stage of model execution"""
        super().__init__()
        # perf_counter is monotonic and high resolution. time.time() follows the
        # system clock and can jump (e.g. NTP adjustments), corrupting intervals.
        self.t0 = time.perf_counter()
        self.time_dict = {}

    def update_time(self, text, state):
        """Store the time elapsed since the last checkpoint under ``text`` and restart the clock.

        :param text: Key under which the elapsed time is stored in the timings dict
        :type text: str
        :param state: The current model state (indexed like a mapping); its
            ``torchbearer.TIMINGS`` entry is set to the timings dict
        """
        self.time_dict[text] = time.perf_counter() - self.t0
        # The same dict object is placed in the state, so later updates are
        # visible through state[torchbearer.TIMINGS] as well.
        state[torchbearer.TIMINGS] = self.time_dict
        # Restart after bookkeeping so the dict/state writes are not attributed
        # to the next stage.
        self.t0 = time.perf_counter()

    def on_start(self, state):
        """Reset the clock, then record ``'OnStart'``.

        The reset means time between construction and the start of the run is
        not attributed to any stage.
        """
        super().on_start(state)
        self.t0 = time.perf_counter()
        self.update_time('OnStart', state)

    def on_start_training(self, state):
        """Record the time since the previous stage under ``'OnStartTraining'``."""
        super().on_start_training(state)
        self.update_time('OnStartTraining', state)

    def on_start_epoch(self, state):
        """Record the time since the previous stage under ``'OnStartEpoch'``."""
        super().on_start_epoch(state)
        self.update_time('OnStartEpoch', state)

    def on_sample(self, state):
        """Record the time since the previous stage under ``'OnSample'``."""
        super().on_sample(state)
        self.update_time('OnSample', state)

    def on_forward(self, state):
        """Record the time since the previous stage under ``'OnForward'``."""
        super().on_forward(state)
        self.update_time('OnForward', state)

    def on_criterion(self, state):
        """Record the time since the previous stage under ``'OnCriterion'``."""
        super().on_criterion(state)
        self.update_time('OnCriterion', state)

    def on_backward(self, state):
        """Record the time since the previous stage under ``'OnBackward'``."""
        super().on_backward(state)
        self.update_time('OnBackward', state)

    def on_step_training(self, state):
        """Record the time since the previous stage under ``'OnStep'``."""
        super().on_step_training(state)
        self.update_time('OnStep', state)

    def on_start_validation(self, state):
        """Record the time since the previous stage under ``'OnStartValidation'``."""
        super().on_start_validation(state)
        self.update_time('OnStartValidation', state)

    def on_sample_validation(self, state):
        """Record the time since the previous stage under ``'OnSampleValidation'``."""
        super().on_sample_validation(state)
        self.update_time('OnSampleValidation', state)

    def on_forward_validation(self, state):
        """Record the time since the previous stage under ``'OnForwardValidation'``."""
        super().on_forward_validation(state)
        self.update_time('OnForwardValidation', state)

    def on_criterion_validation(self, state):
        """Record the time since the previous stage under ``'OnCriterionValidation'``."""
        super().on_criterion_validation(state)
        self.update_time('OnCriterionValidation', state)

    def on_step_validation(self, state):
        """Record the time since the previous stage under ``'OnStepValidation'``."""
        super().on_step_validation(state)
        self.update_time('OnStepValidation', state)

    def get_timings(self):
        """Return the dict mapping stage names to their most recent duration in seconds."""
        return self.time_dict