torchbearer.metrics

Base Classes

The base metric classes exist to enable complex data flow requirements between metrics. All metrics are either instances of Metric or MetricFactory. These can then be collected in a MetricList or a MetricTree. The MetricList simply aggregates calls from a list of metrics, whereas the MetricTree will pass data from its root metric to each child and collect the outputs. This enables complex running metrics and statistics, without needing to compute the underlying values more than once. Typically, constructions of this kind should be handled using the decorator API.

class torchbearer.metrics.metrics.AdvancedMetric(name)

The AdvancedMetric class is a metric which provides different process methods for training and validation. This enables running metrics which do not output intermediate steps during validation.

Parameters:name (str) – The name of the metric.
eval(data_key=None)

Put the metric in eval mode.

Parameters:data_key (Optional(torchbearer.StateKey)) – The torchbearer data_key, if used
process(*args)

Depending on the current mode, return the result of either ‘process_train’ or ‘process_validate’.

Returns:The metric value.
process_final(*args)

Depending on the current mode, return the result of either ‘process_final_train’ or ‘process_final_validate’.

Returns:The final metric value.
process_final_train(*args)

Process the given state and return the final metric value for a training iteration.

Returns:The final metric value for a training iteration.
process_final_validate(*args)

Process the given state and return the final metric value for a validation iteration.

Returns:The final metric value for a validation iteration.
process_train(*args)

Process the given state and return the metric value for a training iteration.

Returns:The metric value for a training iteration.
process_validate(*args)

Process the given state and return the metric value for a validation iteration.

Returns:The metric value for a validation iteration.
train()

Put the metric in train mode.
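The mode-dependent dispatch described for AdvancedMetric can be sketched as follows. This is a minimal, self-contained illustration, not torchbearer's source: the classes `ModeDispatchMetric` and `TrainOnlyCounter` and the `_training` flag are hypothetical, and only the method names mirror the documented ones.

```python
class ModeDispatchMetric:
    """Hypothetical stand-in for AdvancedMetric: routes calls by current mode."""

    def __init__(self, name):
        self.name = name
        self._training = True  # assumption: metrics start in train mode

    def train(self):
        self._training = True

    def eval(self, data_key=None):
        self._training = False

    def process(self, *args):
        # Dispatch to the train or validate variant depending on the mode.
        if self._training:
            return self.process_train(*args)
        return self.process_validate(*args)

    def process_final(self, *args):
        if self._training:
            return self.process_final_train(*args)
        return self.process_final_validate(*args)

    # Subclasses override only the variants they need; the defaults return None.
    def process_train(self, *args):
        return None

    def process_validate(self, *args):
        return None

    def process_final_train(self, *args):
        return None

    def process_final_validate(self, *args):
        return None


class TrainOnlyCounter(ModeDispatchMetric):
    """Counts batches, but only reports during training."""

    def __init__(self):
        super().__init__('batches')
        self.count = 0

    def process_train(self, *args):
        self.count += 1
        return self.count


metric = TrainOnlyCounter()
print(metric.process({}))  # 1
metric.eval()
print(metric.process({}))  # None: no output during validation
```

Keeping the dispatch in the base class means a subclass such as a running metric only overrides the mode it cares about.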

class torchbearer.metrics.metrics.Metric(name)

Base metric class. process() is called on each batch and process_final() at the end of each epoch. The metric contract allows metrics to take any args but not kwargs. The initial metric call will be given state; however, subsequent metrics can pass any values desired.

Note

All metrics must extend this class.

Parameters:name (str) – The name of the metric
eval(data_key=None)

Put the metric in eval mode during model validation.

process(*args)

Process the given state and return the metric value. Called on each batch.

process_final(*args)

Process the given state and return the final metric value. Called at the end of each epoch.

reset(state)

Reset the metric, called before the start of an epoch.

Parameters:state – The current state dict of the Model.
train()

Put the metric in train mode during model training.
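The calling contract described for Metric (reset before an epoch, process on each batch, process_final at the end, positional args only) can be illustrated with a small driver loop. `SampleCounter` and `run_epoch` are hypothetical names for this sketch; in torchbearer these calls are made by torchbearer.Model during fitting, and the 'y_pred'/'y_true' state keys are an assumption taken from the examples elsewhere on this page.

```python
class SampleCounter:
    """Hypothetical metric following the Metric contract (positional args only)."""

    def __init__(self, name='samples'):
        self.name = name
        self.total = 0

    def reset(self, state):
        # Called before the start of each epoch.
        self.total = 0

    def process(self, *args):
        # The first metric in a chain receives the state dict as its only argument.
        state = args[0]
        self.total += len(state['y_true'])
        return self.total

    def process_final(self, *args):
        # Called once at the end of the epoch.
        return self.total


def run_epoch(metric, batches):
    """Hypothetical driver calling the metric in the documented order."""
    metric.reset({})
    for y_pred, y_true in batches:
        metric.process({'y_pred': y_pred, 'y_true': y_true})
    return metric.process_final({})


batches = [([0.1, 0.9], [0, 1]), ([0.4], [1])]
print(run_epoch(SampleCounter(), batches))  # 3
```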

class torchbearer.metrics.metrics.MetricList(metric_list)

The MetricList class is a wrapper for a list of metrics which acts as a single metric and produces a dictionary of outputs.

Parameters:metric_list (list) – The list of metrics to be wrapped. If the list contains a MetricList, this will be unwrapped. Any strings in the list will be retrieved from metrics.DEFAULT_METRICS.
eval(data_key=None)

Put each metric in eval mode.

process(*args)

Process each metric and wrap the results in a dictionary which maps metric names to values.

Returns:dict[str,any] – A dictionary which maps metric names to values.
process_final(*args)

Process the final value of each metric and wrap the results in a dictionary which maps metric names to values.

Returns:dict[str,any] – A dictionary which maps metric names to values.
reset(state)

Reset each metric with the given state.

Parameters:state – The current state dict of the Model.
train()

Put each metric in train mode.
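A minimal sketch of the MetricList behaviour described above: strings are resolved through a registry, nested lists are unwrapped, and each metric's dict output is merged into a single dictionary. `SimpleMetricList`, `ConstantMetric` and `REGISTRY` are hypothetical names; the real class resolves strings via metrics.DEFAULT_METRICS, whose exact contents are not modelled here.

```python
class ConstantMetric:
    """Hypothetical metric that always reports the same value as a dict."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def process(self, *args):
        return {self.name: self.value}

    def process_final(self, *args):
        return {self.name: self.value}


# Hypothetical stand-in for metrics.DEFAULT_METRICS: key -> zero-arg constructor.
REGISTRY = {'const': lambda: ConstantMetric('const', 1.0)}


class SimpleMetricList:
    def __init__(self, metric_list):
        self.metric_list = []
        for metric in metric_list:
            if isinstance(metric, str):
                metric = REGISTRY[metric]()  # look strings up by key
            if isinstance(metric, SimpleMetricList):
                self.metric_list.extend(metric.metric_list)  # unwrap nested lists
            else:
                self.metric_list.append(metric)

    def _collect(self, method, *args):
        result = {}
        for metric in self.metric_list:
            out = getattr(metric, method)(*args)
            if out:  # skip metrics that produce no output for this call
                result.update(out)
        return result

    def process(self, *args):
        return self._collect('process', *args)

    def process_final(self, *args):
        return self._collect('process_final', *args)


inner = SimpleMetricList([ConstantMetric('b', 2.0)])
outer = SimpleMetricList(['const', inner, ConstantMetric('c', 3.0)])
print(outer.process({}))  # {'const': 1.0, 'b': 2.0, 'c': 3.0}
```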

class torchbearer.metrics.metrics.MetricTree(metric)

A tree structure which has a node Metric and some children. Upon execution, the node is called with the input and its output is passed to each of the children. A dict is updated with the results.

Parameters:metric (Metric) – The metric to act as the root node of the tree / subtree
add_child(child)

Add a child to this node of the tree.

Parameters:child (Metric) – The child to add
Returns:None
eval(data_key=None)

Put the metric in eval mode during model validation.

process(*args)

Process this node and then pass the output to each child.

Returns:A dict containing all results from the children
process_final(*args)

Process this node and then pass the output to each child.

Returns:A dict containing all results from the children
reset(state)

Reset the metric, called before the start of an epoch.

Parameters:state – The current state dict of the Model.
train()

Put the metric in train mode during model training.
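The tree data flow can be sketched in plain Python: the root is evaluated once per call and its output is handed to every child, whose dict outputs are merged. `SimpleTree`, `SumRoot` and `Scaled` are hypothetical classes for illustration only, and only process() is shown; in torchbearer the children would typically be aggregators such as Mean or Std wrapped in ToDict.

```python
class SumRoot:
    """Hypothetical root metric: computes y_pred + y_true once per batch."""

    def __init__(self):
        self.calls = 0  # counts how often the underlying value is computed

    def process(self, state):
        self.calls += 1
        return state['y_pred'] + state['y_true']


class Scaled:
    """Hypothetical child that reports a scaled copy of its input as a dict."""

    def __init__(self, name, factor):
        self.name = name
        self.factor = factor

    def process(self, value):
        return {self.name: value * self.factor}


class SimpleTree:
    def __init__(self, root):
        self.root = root
        self.children = []

    def add_child(self, child):
        self.children.append(child)

    def process(self, *args):
        value = self.root.process(*args)  # evaluate the root exactly once
        results = {}
        for child in self.children:
            results.update(child.process(value))
        return results


root = SumRoot()
tree = SimpleTree(root)
tree.add_child(Scaled('double', 2))
tree.add_child(Scaled('half', 0.5))
print(tree.process({'y_pred': 4, 'y_true': 6}))  # {'double': 20, 'half': 5.0}
print(root.calls)  # 1
```

However many children are attached, the root value is computed only once per batch, which is the point of the tree structure.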

torchbearer.metrics.metrics.add_default(key, metric, *args, **kwargs)
torchbearer.metrics.metrics.get_default(key)
torchbearer.metrics.metrics.no_grad()

Decorators - The Decorator API

The decorator API is the core way to interact with metrics in torchbearer. All of the classes and functionality handled here can be reproduced by manually interacting with the classes if necessary. Broadly speaking, the decorator API is used to construct a MetricFactory which will build a MetricTree that handles data flow between instances of Mean, RunningMean, Std etc.

torchbearer.metrics.decorators.default_for_key(key, *args, **kwargs)

The default_for_key() decorator will register the given metric in the global metric dict (metrics.DEFAULT_METRICS) so that it can be referenced by name in instances of MetricList such as in the list given to the torchbearer.Model.

Example:

@default_for_key('acc')
class CategoricalAccuracy(metrics.BatchLambda):
    ...
Parameters:
  • key (str) – The key to use when referencing the metric
  • args – Any args to pass to the underlying metric when constructed
  • kwargs – Any keyword args to pass to the underlying metric when constructed
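The registration idea behind default_for_key() can be sketched as a decorator that stores a constructor together with its args under a key, so that a string can later be turned into a fresh metric instance. This is a simplified illustration: `register_default`, `build_default`, `_DEFAULTS` and the `TopK` class are hypothetical, and the exact storage format of metrics.DEFAULT_METRICS is assumed.

```python
_DEFAULTS = {}  # hypothetical stand-in for metrics.DEFAULT_METRICS


def register_default(key, *args, **kwargs):
    """Decorator: remember how to build the decorated class under `key`."""
    def decorator(cls):
        _DEFAULTS[key] = (cls, args, kwargs)
        return cls  # the class itself is returned unchanged
    return decorator


def build_default(key):
    """Construct a new metric instance from a registered key."""
    cls, args, kwargs = _DEFAULTS[key]
    return cls(*args, **kwargs)


@register_default('top_5_acc', k=5)
@register_default('top_10_acc', k=10)
class TopK:
    def __init__(self, k):
        self.k = k


print(build_default('top_10_acc').k)  # 10
```

Stacking the decorator registers one class under several keys with different constructor arguments, and each lookup builds a new instance.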
torchbearer.metrics.decorators.lambda_metric(name, on_epoch=False)

The lambda_metric() decorator is used to convert a function of y_pred and y_true into a Metric instance. This can be used as in the following example:

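from torchbearer import Model, metrics

@metrics.lambda_metric('my_metric')
def my_metric(y_pred, y_true):
    ... # Calculate some metric

# net, optimizer and loss are assumed to be defined elsewhere
model = Model(net, optimizer, loss, metrics=[my_metric])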
Parameters:
  • name – The name of the metric (e.g. ‘loss’)
  • on_epoch – If True the metric will be an instance of EpochLambda instead of BatchLambda
Returns:

A decorator which replaces a function with a Metric
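Conceptually, lambda_metric() wraps the function in an object whose process() pulls 'y_pred' and 'y_true' out of the state and calls the function, mirroring BatchLambda. The sketch below is a simplified, self-contained version: `simple_lambda_metric` and `SimpleBatchLambda` are hypothetical names, and the on_epoch=True path (EpochLambda) is omitted.

```python
class SimpleBatchLambda:
    """Hypothetical analogue of BatchLambda: applies a function to each batch."""

    def __init__(self, name, metric_function):
        self.name = name
        self._fn = metric_function

    def process(self, *args):
        state = args[0]
        return self._fn(state['y_pred'], state['y_true'])


def simple_lambda_metric(name):
    """Decorator that replaces a function with a metric object."""
    def decorator(metric_function):
        return SimpleBatchLambda(name, metric_function)
    return decorator


@simple_lambda_metric('abs_error')
def abs_error(y_pred, y_true):
    return abs(y_pred - y_true)


print(abs_error.name, abs_error.process({'y_pred': 3.0, 'y_true': 5.0}))  # abs_error 2.0
```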

torchbearer.metrics.decorators.mean(clazz)

The mean() decorator is used to add a Mean to the MetricTree which will output a mean value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Mean will also be wrapped in a ToDict for simplicity.

Example:

>>> import torch
>>> from torchbearer import metrics

>>> @metrics.mean
... @metrics.lambda_metric('my_metric')
... def metric(y_pred, y_true):
...     return y_pred + y_true
...
>>> metric.reset({})
>>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
{}
>>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
{}
>>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
{}
>>> metric.process_final()
{'my_metric': 6.0}
Parameters:clazz – The class to decorate
Returns:A MetricTree with a Mean appended or a wrapper class that extends MetricTree
torchbearer.metrics.decorators.running_mean(clazz=None, batch_size=50, step_size=10)

The running_mean() decorator is used to add a RunningMean to the MetricTree. If the inner class is not a MetricTree then one will be created. The RunningMean will be wrapped in a ToDict (with ‘running_’ prepended to the name) for simplicity.

Note

The decorator does not need to be called: both @running_mean and @running_mean() are acceptable.

Example:

>>> import torch
>>> from torchbearer import metrics

>>> @metrics.running_mean(step_size=2) # Update every 2 steps
... @metrics.lambda_metric('my_metric')
... def metric(y_pred, y_true):
...     return y_pred + y_true
...
>>> metric.reset({})
>>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
{'running_my_metric': 4.0}
>>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
{'running_my_metric': 4.0}
>>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8, triggers update
{'running_my_metric': 6.0}
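Parameters:
  • clazz – The class to decorate
  • batch_size (int) – The size of the deque of previous results to store, passed to the RunningMean
  • step_size (int) – The number of iterations between aggregations, passed to the RunningMean
Returns:

A decorator, a MetricTree instance or a wrapper class that extends MetricTree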

torchbearer.metrics.decorators.std(clazz)

The std() decorator is used to add a Std to the MetricTree which will output a population standard deviation value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Std will also be wrapped in a ToDict (with ‘_std’ appended) for simplicity.

Example:

>>> import torch
>>> from torchbearer import metrics

>>> @metrics.std
... @metrics.lambda_metric('my_metric')
... def metric(y_pred, y_true):
...     return y_pred + y_true
...
>>> metric.reset({})
>>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
{}
>>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
{}
>>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
{}
>>> '%.4f' % metric.process_final()['my_metric_std']
'1.6330'
Parameters:clazz – The class to decorate
Returns:A MetricTree with a Std appended or a wrapper class that extends MetricTree
torchbearer.metrics.decorators.to_dict(clazz)

The to_dict() decorator is used to wrap either a Metric class or a Metric instance with a ToDict instance. The result is that future output will be wrapped in a dict[name, value].

Example:

>>> from torchbearer import metrics

>>> @metrics.lambda_metric('my_metric')
... def my_metric(y_pred, y_true):
...     return y_pred + y_true
...
>>> my_metric.process({'y_pred':4, 'y_true':5})
9

>>> @metrics.to_dict
... @metrics.lambda_metric('my_metric')
... def my_metric(y_pred, y_true):
...     return y_pred + y_true
...
>>> my_metric.process({'y_pred':4, 'y_true':5})
{'my_metric': 9}
Parameters:clazz – The class to decorate
Returns:A ToDict instance or a ToDict wrapper of the given class

Metric Wrappers

Metric wrappers are classes which wrap instances of Metric or, in the case of EpochLambda and BatchLambda, functions. Typically, these should not be used directly (although this is entirely possible), but via the decorator API.

class torchbearer.metrics.wrappers.BatchLambda(name, metric_function)

A metric which returns the output of the given function on each batch.

Parameters:
  • name (str) – The name of the metric.
  • metric_function – A metric function(‘y_pred’, ‘y_true’) to wrap.
process(*args)

Return the output of the wrapped function.

Parameters:args (dict) – The torchbearer.Model state.
Returns:The value of the metric function(‘y_pred’, ‘y_true’).
class torchbearer.metrics.wrappers.EpochLambda(name, metric_function, running=True, step_size=50)

A metric wrapper which computes the given function for concatenated values of ‘y_true’ and ‘y_pred’ each epoch. Can be used as a running metric which computes the function for batches of outputs with a given step size during training.

Parameters:
  • name (str) – The name of the metric.
  • metric_function – The function(‘y_pred’, ‘y_true’) to use as the metric.
  • running (bool) – True if this should act as a running metric.
  • step_size (int) – Step size to use between calls if running=True.
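A rough sketch of the EpochLambda behaviour: outputs from each training batch are accumulated, the function is re-evaluated on everything seen so far every step_size batches when running=True, and the final value is computed over the whole epoch. Python lists stand in for torch.cat concatenation along dimension 0; `SimpleEpochLambda` and `accuracy` are hypothetical names, and the assumption that the first batch triggers an evaluation is labelled in the code.

```python
class SimpleEpochLambda:
    def __init__(self, name, metric_function, running=True, step_size=50):
        self.name = name
        self._fn = metric_function
        self._running = running
        self._step_size = step_size
        self.reset({})

    def reset(self, state):
        # Clear the caches at the start of every epoch.
        self._y_pred, self._y_true = [], []
        self._steps = 0
        self._result = None

    def process_train(self, state):
        # Concatenate this batch onto everything seen so far.
        self._y_pred.extend(state['y_pred'])
        self._y_true.extend(state['y_true'])
        # Assumption: evaluate on step 0 and then every step_size steps.
        if self._running and self._steps % self._step_size == 0:
            self._result = self._fn(self._y_pred, self._y_true)
        self._steps += 1
        return self._result

    def process_final_train(self, *args):
        # Evaluate once over the aggregated epoch outputs.
        return self._fn(self._y_pred, self._y_true)


def accuracy(y_pred, y_true):
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)


metric = SimpleEpochLambda('acc', accuracy, step_size=2)
for preds, targets in [([1, 0], [1, 1]), ([1, 1], [1, 1]), ([0, 0], [1, 0])]:
    print(metric.process_train({'y_pred': preds, 'y_true': targets}))
print(metric.process_final_train())
```

Because the function sees all outputs accumulated so far, this suits metrics that are not simple averages of per-batch values (ROC AUC, for instance).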
process_final_train(*args)

Evaluate the function with the aggregated outputs.

Returns:The result of the function.
process_final_validate(*args)

Evaluate the function with the aggregated outputs.

Returns:The result of the function.
process_train(*args)

Concatenate the ‘y_true’ and ‘y_pred’ from the state along the 0 dimension. If this is a running metric, evaluate the function every step_size steps.

Parameters:args (dict) – The torchbearer.Model state.
Returns:The current running result.
process_validate(*args)

During validation, just concatenate ‘y_true’ and ‘y_pred’.

Parameters:args (dict) – The torchbearer.Model state.
reset(state)

Reset the ‘y_true’ and ‘y_pred’ caches.

Parameters:state (dict) – The torchbearer.Model state.
class torchbearer.metrics.wrappers.ToDict(metric)

The ToDict class is an AdvancedMetric which will put output from the inner Metric in a dict (mapping metric name to value) before returning. When in eval mode, ‘val_’ will be prepended to the metric name.

Example:

>>> from torchbearer import metrics

>>> @metrics.lambda_metric('my_metric')
... def my_metric(y_pred, y_true):
...     return y_pred + y_true
...
>>> metric = metrics.ToDict(my_metric().build())
>>> metric.process({'y_pred': 4, 'y_true': 5})
{'my_metric': 9}
>>> metric.eval()
>>> metric.process({'y_pred': 4, 'y_true': 5})
{'val_my_metric': 9}
Parameters:metric (metrics.Metric) – The Metric instance to wrap.
eval(data_key=None)

Put the metric in eval mode.

Parameters:data_key (Optional(torchbearer.StateKey)) – The torchbearer data_key, if used
process_final_train(*args)

Process the given state and return the final metric value for a training iteration.

Returns:The final metric value for a training iteration.
process_final_validate(*args)

Process the given state and return the final metric value for a validation iteration.

Returns:The final metric value for a validation iteration.
process_train(*args)

Process the given state and return the metric value for a training iteration.

Returns:The metric value for a training iteration.
process_validate(*args)

Process the given state and return the metric value for a validation iteration.

Returns:The metric value for a validation iteration.
reset(state)

Reset the metric, called before the start of an epoch.

Parameters:state – The current state dict of the Model.
train()

Put the metric in train mode.

Metric Aggregators

Aggregators are a special kind of Metric which take as input the output from a previous metric or metrics. As a result, via a MetricTree, a series of aggregators can collect statistics such as Mean or Standard Deviation without needing to compute the underlying metric multiple times. This can, however, make the aggregators complex to use. It is therefore typically better to use the decorator API.

class torchbearer.metrics.aggregators.Mean(name)

Metric aggregator which calculates the mean of process outputs between calls to reset.

Parameters:name (str) – The name of this metric.
process(*args)

Add the input to the rolling sum.

Parameters:args (torch.Tensor) – The output of some previous call to Metric.process().
process_final(*args)

Compute and return the mean of all metric values since the last call to reset.

Returns:The mean of the metric values since the last call to reset.
reset(state)

Reset the running count and total.

Parameters:state (dict) – The model state.
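The Mean aggregator only needs a running total and a count, both cleared by reset. The sketch below uses plain floats where the real class receives the output of a previous metric (typically a torch.Tensor). `SimpleMean` is a hypothetical name, and treating a list as a batch of per-sample values (one count per element) is an assumption labelled in the code.

```python
class SimpleMean:
    def __init__(self, name):
        self.name = name
        self.reset({})

    def reset(self, state):
        # Clear the running total and count between epochs.
        self._sum = 0.0
        self._count = 0

    def process(self, *args):
        value = args[0]
        # Assumption: a list behaves like a per-sample tensor, one count per element.
        values = value if isinstance(value, (list, tuple)) else [value]
        self._sum += sum(values)
        self._count += len(values)

    def process_final(self, *args):
        return self._sum / self._count


mean = SimpleMean('my_metric')
for v in (4.0, 6.0, 8.0):
    mean.process(v)
print(mean.process_final())  # 6.0
```

process() returns nothing here, so only process_final() reports a value, in line with the empty dicts seen in the mean() example.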
class torchbearer.metrics.aggregators.RunningMean(name, batch_size=50, step_size=10)

A RunningMetric which outputs the mean of a sequence of its input over the course of an epoch.

Parameters:
  • name (str) – The name of this running mean.
  • batch_size (int) – The size of the deque of previous results to store.
  • step_size (int) – The number of iterations between aggregations.
class torchbearer.metrics.aggregators.RunningMetric(name, batch_size=50, step_size=10)

A metric which aggregates batches of results and presents a method to periodically process these into a value.

Note

Running metrics only provide output during training.

Parameters:
  • name (str) – The name of the metric.
  • batch_size (int) – The size of the deque of previous results to store.
  • step_size (int) – The number of iterations between aggregations.
process_train(*args)

Add the current metric value to the cache and call ‘_step’ if needed.

Parameters:args – The output of some Metric
Returns:The current metric value.
reset(state)

Reset the step counter. Does not clear the cache.

Parameters:state (dict) – The current model state.
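The running behaviour can be reproduced with a bounded deque: every value is cached (the oldest is dropped beyond batch_size), and the aggregate is recomputed only every step_size calls, with the last result returned in between. This sketch assumes the step counter starts at zero so that the first call triggers an update, which matches the running_mean() example above; `SimpleRunningMean` is a hypothetical name.

```python
from collections import deque


class SimpleRunningMean:
    def __init__(self, name, batch_size=50, step_size=10):
        self.name = name
        self._cache = deque(maxlen=batch_size)  # oldest results fall off automatically
        self._step_size = step_size
        self._result = None
        self._i = 0

    def reset(self, state):
        # As documented for RunningMetric: reset the step counter, keep the cache.
        self._i = 0

    def _step(self, cache):
        return sum(cache) / len(cache)

    def process_train(self, value):
        self._cache.append(value)
        # Assumption: update on step 0 and then every step_size steps.
        if self._i % self._step_size == 0:
            self._result = self._step(list(self._cache))
        self._i += 1
        return self._result


running = SimpleRunningMean('running_my_metric', step_size=2)
print([running.process_train(v) for v in (4.0, 6.0, 8.0)])  # [4.0, 4.0, 6.0]
```

Recomputing only every step_size calls keeps the cost of a large batch_size window low while still giving a smoothed value during training.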
class torchbearer.metrics.aggregators.Std(name)

Metric aggregator which calculates the standard deviation of process outputs between calls to reset.

Parameters:name (str) – The name of this metric.
process(*args)

Compute values required for the std from the input.

Parameters:args (torch.Tensor) – The output of some previous call to Metric.process().
process_final(*args)

Compute and return the final standard deviation.

Returns:The standard deviation of each observation since the last reset call.
reset(state)

Reset the statistics to compute the next deviation.

Parameters:state (dict) – The model state.
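The population standard deviation can be computed in a single pass from a count, a sum and a sum of squares, as sqrt(E[x^2] - E[x]^2). The sketch below uses plain floats and reproduces the value in the std() example (inputs 4, 6 and 8 give about 1.6330); `SimpleStd` is a hypothetical name, and the real class may organise the computation differently.

```python
import math


class SimpleStd:
    def __init__(self, name):
        self.name = name
        self.reset({})

    def reset(self, state):
        # Clear the statistics before the next deviation is computed.
        self._count = 0
        self._sum = 0.0
        self._sum_sq = 0.0

    def process(self, *args):
        value = args[0]
        self._count += 1
        self._sum += value
        self._sum_sq += value * value

    def process_final(self, *args):
        mean = self._sum / self._count
        variance = self._sum_sq / self._count - mean * mean
        # Guard against tiny negative values caused by floating-point rounding.
        return math.sqrt(max(variance, 0.0))


std = SimpleStd('my_metric_std')
for v in (4.0, 6.0, 8.0):
    std.process(v)
print('%.4f' % std.process_final())  # 1.6330
```

The sum-of-squares form is used here for brevity; for long epochs or large values, Welford's online algorithm is numerically more robust.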

Base Metrics

Base metrics are the base classes which represent the metrics supplied with torchbearer. They all use the default_for_key() decorator so that they can be accessed in the call to torchbearer.Model via the key strings listed with each metric below.

class torchbearer.metrics.default.DefaultAccuracy

The default accuracy metric loads in a different accuracy metric depending on the loss function or criterion in use at the start of training. Default for keys: acc, accuracy. The following bindings are in place for both nn and functional variants:

eval(data_key=None)

Put the metric in eval mode during model validation.

process(*args)

Process the given state with the accuracy metric selected for the current criterion and return its value.

process_final(*args)

Return the final value of the selected accuracy metric at the end of each epoch.

reset(state)

Reset the metric, called before the start of an epoch.

Parameters:state – The current state dict of the Model.
train()

Put the metric in train mode during model training.

class torchbearer.metrics.primitives.BinaryAccuracy

Binary accuracy metric. Uses torch.eq to compare predictions to targets. Decorated with a mean and running_mean. Default for key: ‘binary_acc’.

class torchbearer.metrics.primitives.CategoricalAccuracy(ignore_index=-100)

Categorical accuracy metric. Uses torch.max to determine predictions and compares to targets. Decorated with a mean, running_mean and std. Default for key: ‘cat_acc’

Parameters:ignore_index (int) – Specifies a target value that is ignored and does not contribute to the metric output. See https://pytorch.org/docs/stable/nn.html#crossentropyloss
class torchbearer.metrics.primitives.TopKCategoricalAccuracy(k=5, ignore_index=-100)

Top K Categorical accuracy metric. Uses torch.topk to determine the top k predictions and compares to targets. Decorated with a mean, running_mean and std. Default for keys: ‘top_5_acc’, ‘top_10_acc’.

Parameters:
  • k (int) – The number of top predictions to consider.
  • ignore_index (int) – Specifies a target value that is ignored and does not contribute to the metric output. See https://pytorch.org/docs/stable/nn.html#crossentropyloss
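The per-sample logic of CategoricalAccuracy and TopKCategoricalAccuracy can be sketched in pure Python: a sample counts as correct if its target is among the k highest-scoring classes (k=1 gives plain categorical accuracy), and samples whose target equals ignore_index are dropped. The real metrics operate on torch tensors via torch.max and torch.topk; `topk_correct` is a hypothetical helper, the assumption that per-sample results are then averaged by the mean decorator is not confirmed here, and tie-breaking between equal scores is not modelled.

```python
def topk_correct(y_pred, y_true, k=1, ignore_index=-100):
    """Return one 1.0/0.0 entry per non-ignored sample.

    y_pred: list of per-class score lists; y_true: list of class indices.
    """
    results = []
    for scores, target in zip(y_pred, y_true):
        if target == ignore_index:
            continue  # ignored targets do not contribute to the metric
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        results.append(1.0 if target in ranked[:k] else 0.0)
    return results


y_pred = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
y_true = [1, 1, -100]
top1 = topk_correct(y_pred, y_true, k=1)
top2 = topk_correct(y_pred, y_true, k=2)
print(sum(top1) / len(top1), sum(top2) / len(top2))  # 0.5 1.0
```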

class torchbearer.metrics.primitives.MeanSquaredError

Mean squared error metric. Computes the pixelwise squared error which is then averaged with decorators. Decorated with a mean and running_mean. Default for key: ‘mse’.

class torchbearer.metrics.primitives.Loss

Simply returns the ‘loss’ value from the model state. Decorated with a mean, running_mean and std. Default for key: ‘loss’.

class torchbearer.metrics.primitives.Epoch

Returns the ‘epoch’ from the model state. Default for key: ‘epoch’.

class torchbearer.metrics.roc_auc_score.RocAucScore(one_hot_labels=True, one_hot_offset=0, one_hot_classes=10)

Area Under ROC curve metric. Default for keys: ‘roc_auc’, ‘roc_auc_score’.

Note

Requires sklearn.metrics.

Parameters:
  • one_hot_labels (bool) – If True, convert the labels to a one hot encoding. Required if they are not already.
  • one_hot_offset (int) – Subtracted from class labels, use if not already zero based.
  • one_hot_classes (int) – Number of classes for the one hot encoding.
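The one_hot_labels, one_hot_offset and one_hot_classes parameters describe a label preprocessing step performed before the score is computed. A hedged sketch of that step in pure Python follows; `to_one_hot` is a hypothetical helper, and the resulting indicator matrix is the kind of label input that sklearn.metrics.roc_auc_score accepts alongside per-class scores (that call is omitted to avoid the sklearn dependency).

```python
def to_one_hot(labels, offset=0, num_classes=10):
    """Convert integer class labels to one-hot rows.

    The offset is subtracted first, e.g. offset=1 maps 1-based labels to 0-based.
    """
    rows = []
    for label in labels:
        index = label - offset
        if not 0 <= index < num_classes:
            raise ValueError('label %d is out of range after offset' % label)
        row = [0] * num_classes
        row[index] = 1
        rows.append(row)
    return rows


print(to_one_hot([1, 3, 2], offset=1, num_classes=3))
# [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```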

Timer

class torchbearer.metrics.timer.TimerMetric(time_keys=())
get_timings()
on_backward(state)

Perform some action with the given state as context after backward has been called on the loss.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion(state)

Perform some action with the given state as context after the criterion has been evaluated.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_criterion_validation(state)

Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end(state)

Perform some action with the given state as context at the end of the model fitting.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_epoch(state)

Perform some action with the given state as context at the end of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_training(state)

Perform some action with the given state as context after the training loop has completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_end_validation(state)

Perform some action with the given state as context at the end of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward(state)

Perform some action with the given state as context after the forward pass (model output) has been completed.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_forward_validation(state)

Perform some action with the given state as context after the forward pass (model output) has been completed with the validation data.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample(state)

Perform some action with the given state as context after data has been sampled from the generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_sample_validation(state)

Perform some action with the given state as context after data has been sampled from the validation generator.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start(state)

Perform some action with the given state as context at the start of a model fit.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_epoch(state)

Perform some action with the given state as context at the start of each epoch.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_training(state)

Perform some action with the given state as context at the start of the training loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_start_validation(state)

Perform some action with the given state as context at the start of the validation loop.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_training(state)

Perform some action with the given state as context after step has been called on the optimiser.

Parameters:state (dict[str,any]) – The current state dict of the Model.
on_step_validation(state)

Perform some action with the given state as context at the end of each validation step.

Parameters:state (dict[str,any]) – The current state dict of the Model.
process(*args)

reset(state)

Reset the metric, called before the start of an epoch.

Parameters:state – The current state dict of the Model.
update_time(text, metric, state)