Source code for torchbearer.state

import warnings

# Registry of every key string issued by StateKey; consulted to keep generated keys unique.
__keys__ = list()


def state_key(key):
    """Computes and returns a non-conflicting key for the state dictionary when given a seed key.

    The result is a :class:`StateKey`; if ``key`` has already been issued, the StateKey constructor appends a
    numeric suffix (``_1``, ``_2``, ...) so the new key does not collide with an existing one.

    :param key: The seed key - basis for new state key
    :type key: str
    :return: New state key
    :rtype: StateKey
    """
    return StateKey(key)
class StateKey:
    """
    A unique key for the state dictionary, derived from a seed string.

    If the seed has already been issued, the suffixes ``_1``, ``_2``, ... are tried in order until an unused key
    is found. The chosen key is appended to the module-level ``__keys__`` registry. Instances compare equal to
    (and hash the same as) their key string.

    :param key: Base key
    """
    def __init__(self, key):
        super().__init__()
        self.key = self._gen_key_(key)

    def _gen_key_(self, key):
        # Start from the bare seed; only fall back to suffixed variants while the candidate is already taken.
        candidate, suffix = key, 0
        while candidate in __keys__:
            suffix += 1
            candidate = key + '_' + str(suffix)
        __keys__.append(candidate)
        return candidate

    def __repr__(self):
        return self.key

    # str() and repr() both yield the raw key string.
    __str__ = __repr__

    def __eq__(self, other):
        return str(other) == self.key

    def __hash__(self):
        # Hash like the underlying string so a StateKey and its key string hit the same dict bucket.
        return hash(self.key)
class State(dict):
    """
    State dictionary that behaves like a python dict but accepts StateKeys.

    Item get/set/delete, ``in`` and :meth:`update` all route keys through :meth:`get_key`, which emits a warning
    when a plain ``str`` is used instead of a StateKey.
    """
    def __init__(self):
        super().__init__()

    def get_key(self, statekey, stacklevel=2):
        """Return ``statekey`` unchanged, warning if it is a plain string.

        :param statekey: The key used to access the state
        :param stacklevel: Forwarded to :func:`warnings.warn`. The default of 2 attributes the warning to the
            direct caller of this method; the dict methods below pass 3 so the warning points at the user's code
            instead of at this module.
        :type stacklevel: int
        :return: ``statekey``, unchanged
        """
        if isinstance(statekey, str):
            warnings.warn("State was accessed with a string: {}, generate keys with StateKey(str).".format(statekey), stacklevel=stacklevel)
        return statekey

    # stacklevel=3 below: frame 1 is get_key, frame 2 is the dunder, frame 3 is the caller doing state[...] etc.
    def __getitem__(self, key):
        return super().__getitem__(self.get_key(key, stacklevel=3))

    def __setitem__(self, key, val):
        super().__setitem__(self.get_key(key, stacklevel=3), val)

    def __delitem__(self, val):
        # Route through get_key like every other accessor so string-based deletion is also flagged.
        super().__delitem__(self.get_key(val, stacklevel=3))

    def __contains__(self, o: object) -> bool:
        return super().__contains__(self.get_key(o, stacklevel=3))

    def update(self, d=(), **kwargs):
        """Update the state from a mapping or an iterable of ``(key, value)`` pairs, plus keyword arguments.

        Mirrors :meth:`dict.update`: if ``d`` has a ``keys`` attribute it is treated as a mapping, otherwise as an
        iterable of pairs. Every key is passed through :meth:`get_key` (keyword-argument keys are strings and so
        always trigger the string-access warning).

        :param d: Mapping or iterable of ``(key, value)`` pairs
        :param kwargs: Additional entries to set
        """
        new_dict = {}
        if hasattr(d, 'keys'):
            for key in d:
                new_dict[self.get_key(key, stacklevel=3)] = d[key]
        else:
            for key, value in d:
                new_dict[self.get_key(key, stacklevel=3)] = value
        for key, value in kwargs.items():
            new_dict[self.get_key(key, stacklevel=3)] = value
        super().update(new_dict)
# Built-in state keys used throughout torchbearer. Each state_key call constructs a StateKey, which records its
# key string in the module-level ``__keys__`` registry, so these assignments run (and register) in this order.
VERSION = state_key('torchbearer_version')
MODEL = state_key('model')
CRITERION = state_key('criterion')
OPTIMIZER = state_key('optimizer')
DEVICE = state_key('device')
DATA_TYPE = state_key('dtype')
METRIC_LIST = state_key('metric_list')
METRICS = state_key('metrics')
SELF = state_key('self')
EPOCH = state_key('epoch')
MAX_EPOCHS = state_key('max_epochs')
GENERATOR = state_key('generator')
ITERATOR = state_key('iterator')
STEPS = state_key('steps')
TRAIN_GENERATOR = state_key('train_generator')
TRAIN_STEPS = state_key('train_steps')
TRAIN_DATA = state_key('train_data')
VALIDATION_GENERATOR = state_key('validation_generator')
VALIDATION_STEPS = state_key('validation_steps')
VALIDATION_DATA = state_key('validation_data')
TEST_GENERATOR = state_key('test_generator')
TEST_STEPS = state_key('test_steps')
TEST_DATA = state_key('test_data')
STOP_TRAINING = state_key('stop_training')
Y_TRUE = state_key('y_true')
Y_PRED = state_key('y_pred')
X = state_key('x')
SAMPLER = state_key('sampler')
LOSS = state_key('loss')
FINAL_PREDICTIONS = state_key('final_predictions')
# NOTE(review): the seed string here is 't', not 'batch'; presumably historical. str(BATCH) exposes it, so
# confirm nothing depends on it before renaming.
BATCH = state_key('t')
TIMINGS = state_key('timings')
CALLBACK_LIST = state_key('callback_list')
HISTORY = state_key('history')
BACKWARD_ARGS = state_key('backward_args')

# Legacy
# These are plain strings, not StateKeys, so indexing a State with them triggers its string-access warning.
VALIDATION_ITERATOR = 'validation_iterator'
TRAIN_ITERATOR = 'train_iterator'