Source code for torchbearer.metrics.timer

import time
from torchbearer.callbacks import Callback
import torchbearer
from torchbearer.metrics import Metric

ON_START_TRAINING = 'on_start_training'
ON_START_EPOCH = 'on_start_epoch'
ON_SAMPLE = 'on_sample'
ON_FORWARD = 'on_forward'
ON_CRITERION = 'on_criterion'
ON_BACKWARD = 'on_backward'
ON_STEP_TRAINING = 'on_step_training'
ON_START_VALIDATION = 'on_start_validation'
ON_SAMPLE_VALIDATION = 'on_sample_validation'
ON_FORWARD_VALIDATION = 'on_forward_validaiton'
ON_CRITERION_VALIDATION = 'on_criterion_validation'
ON_STEP_VALIDATION = 'on_step_validation'
TRAIN_TIME = 'train_time'
TOTAL_TIME = 'total_time'
VALIDATION_TIME = 'validation_time'


[docs]class TimerMetric(Callback, Metric): def __init__(self, time_keys=()): """ Timer callback that aggregates timings for each stage of model execution """ super(TimerMetric, self).__init__(name='timer') self.t0 = time.time() self.time_dict = {} # self.init_keys() self.batch_timer = _TimerMetric('t_batch') self.epoch_timer = _TimerMetric('t_epoch') self.train_timer = _TimerMetric('t_train') self.valid_timer = _TimerMetric('t_valid') self.total_timer = _TimerMetric('t_total') self.time_keys = time_keys self.added_callback = False
[docs] def update_time(self, text, metric, state): self.time_dict[text] = metric.process(state) state[torchbearer.TIMINGS] = self.time_dict
[docs] def process(self, *args): super().process(*args) d_out = {key: self.time_dict[key] for key in self.time_keys if key in self.time_dict} return d_out
[docs] def reset(self, state): super().reset(state) if not self.added_callback: state[torchbearer.CALLBACK_LIST].append([self]) self.added_callback = True
[docs] def on_start(self, state): self.t0 = time.time() self.batch_timer.reset(state) self.epoch_timer.reset(state) self.train_timer.reset(state) self.valid_timer.reset(state) self.total_timer.reset(state)
[docs] def on_start_training(self, state): super().on_start_training(state) self.update_time(ON_START_TRAINING, self.batch_timer, state) self.update_time(ON_START_TRAINING, self.train_timer, state)
[docs] def on_start_epoch(self, state): super().on_start_epoch(state) self.update_time(ON_START_EPOCH, self.epoch_timer, state)
[docs] def on_sample(self, state): super().on_sample(state) self.update_time(ON_SAMPLE, self.batch_timer, state)
[docs] def on_forward(self, state): super().on_forward(state) self.update_time(ON_FORWARD, self.batch_timer, state)
[docs] def on_criterion(self, state): super().on_criterion(state) self.update_time(ON_CRITERION, self.batch_timer, state)
[docs] def on_backward(self, state): super().on_backward(state) self.update_time(ON_BACKWARD, self.batch_timer, state)
[docs] def on_step_training(self, state): super().on_step_training(state) self.update_time(ON_STEP_TRAINING, self.batch_timer, state)
[docs] def on_start_validation(self, state): super().on_start_validation(state) self.update_time(ON_START_VALIDATION, self.batch_timer, state)
[docs] def on_sample_validation(self, state): super().on_sample_validation(state) self.update_time(ON_SAMPLE_VALIDATION, self.batch_timer, state)
[docs] def on_forward_validation(self, state): super().on_forward_validation(state) self.update_time(ON_FORWARD_VALIDATION, self.batch_timer, state)
[docs] def on_criterion_validation(self, state): super().on_criterion_validation(state) self.update_time(ON_CRITERION_VALIDATION, self.batch_timer, state)
[docs] def on_step_validation(self, state): super().on_step_validation(state) self.update_time(ON_STEP_VALIDATION, self.batch_timer, state)
[docs] def on_end_training(self, state): super().on_end_training(state) self.valid_timer.reset(state) self.batch_timer.reset(state) self.update_time(TRAIN_TIME, self.train_timer, state)
[docs] def on_end_epoch(self, state): super().on_end_epoch(state) self.batch_timer.reset(state) self.train_timer.reset(state)
[docs] def on_end(self, state): super().on_end(state) self.update_time(TOTAL_TIME, self.total_timer, state) print(self.time_dict)
[docs] def on_end_validation(self, state): super().on_end_validation(state) self.update_time(VALIDATION_TIME, self.valid_timer, state)
[docs] def get_timings(self): return self.time_dict
class _TimerMetric(Metric): def __init__(self, name): super().__init__(name) self.t = time.time() def process(self, *args): super().process(*args) dt = time.time() - self.t self.t = time.time() return dt def reset(self, state): super().reset(state) self.t = time.time()