torchbearer.metrics
The base metric classes exist to enable complex data flow requirements between metrics. All metrics are either instances of Metric or MetricFactory. These can then be collected in a MetricList or a MetricTree. The MetricList simply aggregates calls from a list of metrics, whereas the MetricTree will pass data from its root metric to each child and collect the outputs. This enables complex running metrics and statistics without needing to compute the underlying values more than once. Typically, constructions of this kind should be handled using the decorator API.
Base Classes
- class torchbearer.bases.Metric(name)
Base metric class. process is called on each batch and process_final at the end of each epoch. The metric contract allows metrics to take any positional args but not kwargs. The initial metric call is given state; subsequent metrics can pass any values desired.
Note
All metrics must extend this class.
- Parameters:
name (str) – The name of the metric
- process(*args)
Process the state and update the metric for one iteration.
- Parameters:
args – Arguments given to the metric. If this is a root level metric, it will be given state
- Returns:
None, or the value of the metric for this batch
- process_final(*args)
Process the terminal state and output the final value of the metric.
- Parameters:
args – Arguments given to the metric. If this is a root level metric, it will be given state
- Returns:
None, or the value of the metric for this epoch
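The contract can be illustrated without torchbearer itself. The sketch below is a plain-Python stand-in that assumes only what is stated above (positional args, state passed to a root-level metric, process per batch and process_final per epoch). ToyMetric and MeanLoss are hypothetical names; a real metric would extend torchbearer.bases.Metric instead.

```python
# Toy stand-in for the Metric contract; ToyMetric and MeanLoss are
# hypothetical names and nothing here imports torchbearer.
class ToyMetric:
    def __init__(self, name):
        self.name = name

    def process(self, *args):
        """Called once per batch; returns None or a per-batch value."""
        return None

    def process_final(self, *args):
        """Called once per epoch; returns None or a per-epoch value."""
        return None


class MeanLoss(ToyMetric):
    """Root-level metric: receives state and averages state['loss']."""

    def __init__(self):
        super().__init__('loss')
        self.total = 0.0
        self.count = 0

    def process(self, *args):
        state = args[0]  # a root-level metric is given state as its only arg
        self.total += state['loss']
        self.count += 1
        return state['loss']  # per-batch value

    def process_final(self, *args):
        return self.total / self.count  # per-epoch value


metric = MeanLoss()
for loss in (0.5, 0.25, 0.75):
    metric.process({'loss': loss})
print(metric.process_final())  # 0.5
```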
- class torchbearer.metrics.metrics.AdvancedMetric(name)
The AdvancedMetric class is a metric which provides different process methods for training and validation. This enables running metrics which do not output intermediate steps during validation.
- Parameters:
name (str) – The name of the metric.
- eval(data_key=None)
Put the metric in eval mode.
- Parameters:
data_key (StateKey) – The torchbearer data_key, if used
- process(*args)
Depending on the current mode, return the result of either ‘process_train’ or ‘process_validate’.
- Returns:
The metric value.
- process_final(*args)
Depending on the current mode, return the result of either ‘process_final_train’ or ‘process_final_validate’.
- Returns:
The final metric value.
- process_final_train(*args)
Process the given state and return the final metric value for a training iteration.
- Returns:
The final metric value for a training iteration.
- process_final_validate(*args)
Process the given state and return the final metric value for a validation iteration.
- Returns:
The final metric value for a validation iteration.
- process_train(*args)
Process the given state and return the metric value for a training iteration.
- Returns:
The metric value for a training iteration.
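A minimal sketch of the mode dispatch described above, written as a standalone class rather than against torchbearer. ToyAdvancedMetric and QuietDuringValidation are hypothetical; the train() method and starting in training mode are assumptions, since only eval() is documented above.

```python
# Hypothetical stand-in illustrating AdvancedMetric-style mode dispatch.
class ToyAdvancedMetric:
    def __init__(self, name):
        self.name = name
        self._training = True  # assumption: metrics start in training mode

    def train(self):
        self._training = True

    def eval(self):
        self._training = False

    def process(self, *args):
        if self._training:
            return self.process_train(*args)
        return self.process_validate(*args)

    def process_final(self, *args):
        if self._training:
            return self.process_final_train(*args)
        return self.process_final_validate(*args)

    # Subclasses override these hooks; by default they output nothing.
    def process_train(self, *args):
        return None

    def process_validate(self, *args):
        return None

    def process_final_train(self, *args):
        return None

    def process_final_validate(self, *args):
        return None


class QuietDuringValidation(ToyAdvancedMetric):
    """Outputs a per-batch value in training only; nothing during validation."""

    def process_train(self, *args):
        return args[0]['loss']


m = QuietDuringValidation('loss')
print(m.process({'loss': 1.0}))  # 1.0
m.eval()
print(m.process({'loss': 1.0}))  # None
```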
- class torchbearer.metrics.metrics.MetricList(metric_list)
The MetricList class is a wrapper for a list of metrics which acts as a single metric and produces a dictionary of outputs.
- Parameters:
metric_list (list) – The list of metrics to be wrapped. If the list contains a MetricList, this will be unwrapped. Any strings in the list will be retrieved from metrics.DEFAULT_METRICS.
- process(*args)
Process each metric and wrap the results in a dictionary which maps metric names to values.
- Returns:
A dictionary which maps metric names to values.
- Return type:
dict[str,any]
- process_final(*args)
Process each metric and wrap the results in a dictionary which maps metric names to values.
- Returns:
A dictionary which maps metric names to values.
- Return type:
dict[str,any]
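The dictionary merging performed by a MetricList can be sketched as follows. ToyMetricList and Const are hypothetical stand-ins, and skipping metrics that return None is an assumption of this sketch rather than documented behaviour.

```python
# Hypothetical sketch of how a MetricList-style wrapper could merge outputs.
class Const:
    """Toy metric returning a fixed dict of outputs (hypothetical)."""

    def __init__(self, outputs):
        self.outputs = outputs

    def process(self, *args):
        return dict(self.outputs)


class ToyMetricList:
    def __init__(self, metric_list):
        self.metrics = []
        for m in metric_list:
            # A nested list is flattened, mirroring how a nested MetricList is unwrapped.
            if isinstance(m, ToyMetricList):
                self.metrics.extend(m.metrics)
            else:
                self.metrics.append(m)

    def process(self, *args):
        result = {}
        for m in self.metrics:
            out = m.process(*args)
            if out is not None:  # assumption: metrics with no output are skipped
                result.update(out)
        return result


inner = ToyMetricList([Const({'acc': 0.9})])
outer = ToyMetricList([inner, Const({'loss': 0.1})])
print(outer.process({}))  # {'acc': 0.9, 'loss': 0.1}
```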
- class torchbearer.metrics.metrics.MetricTree(metric)
A tree structure which has a node Metric and some children. Upon execution, the node is called with the input and its output is passed to each of the children. A dict is updated with the results.
Note
If the node output is already a dict (i.e. the node is a standalone metric), this is unwrapped before passing the first value to the children.
- Parameters:
metric (Metric) – The metric to act as the root node of the tree / subtree
- add_child(child)
Add a child to this node of the tree.
- Parameters:
child (Metric) – The child to add
- process(*args)
Process this node and then pass the output to each child.
- Returns:
A dict containing all results from the children
- process_final(*args)
Process this node and then pass the output to each child.
- Returns:
A dict containing all results from the children
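The sketch below shows this data flow in plain Python: the root is evaluated once, a dict output is unwrapped to its first value, and that value is shared by every child. ToyMetricTree, Square and Echo are hypothetical classes, not torchbearer code.

```python
# Hypothetical sketch of MetricTree-style data flow (not torchbearer's code).
class Square:
    """Toy root metric: squares state['x'] and returns it in a dict."""

    def process(self, state):
        return {'square': state['x'] ** 2}


class Echo:
    """Toy child metric: labels whatever value it receives."""

    def __init__(self, name):
        self.name = name

    def process(self, value):
        return {self.name: value}


class ToyMetricTree:
    def __init__(self, root):
        self.root = root
        self.children = []

    def add_child(self, child):
        self.children.append(child)

    def process(self, *args):
        out = self.root.process(*args)  # the root is computed exactly once
        if isinstance(out, dict):
            # Unwrap a dict output and forward its first value to the children.
            out = next(iter(out.values()))
        results = {}
        for child in self.children:
            child_out = child.process(out)
            if child_out is not None:
                results.update(child_out)
        return results


tree = ToyMetricTree(Square())
tree.add_child(Echo('a'))
tree.add_child(Echo('b'))
print(tree.process({'x': 3}))  # {'a': 9, 'b': 9}
```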
Decorators - The Decorator API
The decorator API is the core way to interact with metrics in torchbearer. All of the classes and functionality handled here can be reproduced by manually interacting with the classes if necessary. Broadly speaking, the decorator API is used to construct a MetricFactory which will build a MetricTree that handles data flow between instances of Mean, RunningMean, Std, etc.
- torchbearer.metrics.decorators.default_for_key(key, *args, **kwargs)
The default_for_key() decorator will register the given metric in the global metric dict (metrics.DEFAULT_METRICS) so that it can be referenced by name in instances of MetricList, such as in the list given to the torchbearer.Model.
Example:
    @default_for_key('acc')
    class CategoricalAccuracy(metrics.BatchLambda):
        ...
- Parameters:
key (str) – The key to use when referencing the metric
args – Any args to pass to the underlying metric when constructed
kwargs – Any keyword args to pass to the underlying metric when constructed
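A registry of this kind can be sketched with a module-level dict and a decorator factory. DEFAULT_METRICS, default_for_key, resolve and TopK below are local, hypothetical stand-ins. Storing a zero-argument factory, so that each lookup builds a fresh instance with the bound args and kwargs, is an assumption of the sketch.

```python
# Hypothetical sketch of a default_for_key-style registry; DEFAULT_METRICS,
# default_for_key and resolve are illustrative names, not torchbearer imports.
DEFAULT_METRICS = {}


def default_for_key(key, *args, **kwargs):
    def decorator(cls):
        # Store a factory so each lookup builds a fresh instance with the given args.
        DEFAULT_METRICS[key] = lambda: cls(*args, **kwargs)
        return cls  # the class itself is returned unchanged
    return decorator


@default_for_key('top_5_acc', k=5)
@default_for_key('top_10_acc', k=10)
class TopK:
    def __init__(self, k):
        self.k = k


def resolve(metric_list):
    """Replace string entries with registered metrics; pass other entries through."""
    return [DEFAULT_METRICS[m]() if isinstance(m, str) else m for m in metric_list]


print([m.k for m in resolve(['top_5_acc', 'top_10_acc'])])  # [5, 10]
```

Stacking the decorator registers one class under several keys with different arguments, which is how a single top-k class could serve both ‘top_5_acc’ and ‘top_10_acc’.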
- torchbearer.metrics.decorators.lambda_metric(name, on_epoch=False)
The lambda_metric() decorator is used to convert a function of y_pred and y_true into a Metric instance. This can be used as in the following example:
    @metrics.lambda_metric('my_metric')
    def my_metric(y_pred, y_true):
        ...  # Calculate some metric

    model = Model(metrics=[my_metric])
- Parameters:
name (str) – The name of the metric (e.g. ‘loss’)
on_epoch (bool) – If True the metric will be an instance of EpochLambda instead of BatchLambda
- Returns:
A decorator which replaces a function with a Metric
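As a sketch of the idea, assuming plain 'y_pred' and 'y_true' string keys in the state dict (as in the doctests on this page), a decorator only has to wrap the function in an object whose process reads those two entries. ToyBatchLambda and this local lambda_metric are hypothetical and ignore the on_epoch option.

```python
# Hypothetical sketch of lambda_metric: wrap a (y_pred, y_true) function so
# that it reads its inputs from the state dict on each batch.
class ToyBatchLambda:
    def __init__(self, name, metric_function):
        self.name = name
        self._fn = metric_function

    def process(self, state):
        return self._fn(state['y_pred'], state['y_true'])


def lambda_metric(name):
    def decorator(metric_function):
        # The decorated name now refers to a metric object, not the function.
        return ToyBatchLambda(name, metric_function)
    return decorator


@lambda_metric('abs_error')
def abs_error(y_pred, y_true):
    return abs(y_pred - y_true)


print(abs_error.name, abs_error.process({'y_pred': 3, 'y_true': 5}))  # abs_error 2
```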
- torchbearer.metrics.decorators.mean(clazz=None, dim=None)
The mean() decorator is used to add a Mean to the MetricTree which will output a mean value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Mean will also be wrapped in a ToDict for simplicity.
Example:
    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.mean
    ... @metrics.lambda_metric('my_metric')
    ... def metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
    {}
    >>> metric.process_final()
    {'my_metric': 6.0}
- Parameters:
clazz – The class to decorate
dim (int, tuple) – See Mean
- Returns:
A MetricTree with a Mean appended or a wrapper class that extends MetricTree
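The observable effect of stacking mean() on a lambda metric can be approximated in plain Python, with numbers standing in for tensors. MeanOf, ToyLambda and the local mean and lambda_metric functions are hypothetical. The sketch reproduces the {} per-batch outputs and the {'my_metric': 6.0} epoch result from the example above, but not the MetricTree and ToDict machinery that torchbearer actually builds.

```python
# Hypothetical sketch of what a mean-style decorator arranges: the wrapped
# metric's per-batch output feeds a running sum, and only the epoch mean is
# reported, keyed by the metric's name.
class ToyLambda:
    def __init__(self, name, fn):
        self.name = name
        self.fn = fn

    def process(self, state):
        return self.fn(state['y_pred'], state['y_true'])


class MeanOf:
    """Wraps a metric and averages its per-batch outputs over an epoch."""

    def __init__(self, inner):
        self.inner = inner
        self.reset({})

    def reset(self, state):
        self.total, self.count = 0.0, 0

    def process(self, state):
        self.total += self.inner.process(state)
        self.count += 1
        return {}  # nothing is reported per batch

    def process_final(self, state=None):
        return {self.inner.name: self.total / self.count}


def mean(metric):
    return MeanOf(metric)


def lambda_metric(name):
    return lambda fn: ToyLambda(name, fn)


@mean
@lambda_metric('my_metric')
def metric(y_pred, y_true):
    return y_pred + y_true


for v in (2, 3, 4):
    metric.process({'y_pred': v, 'y_true': v})  # 4, 6, 8
print(metric.process_final())  # {'my_metric': 6.0}
```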
- torchbearer.metrics.decorators.running_mean(clazz=None, batch_size=50, step_size=10, dim=None)
The running_mean() decorator is used to add a RunningMean to the MetricTree. If the inner class is not a MetricTree then one will be created. The RunningMean will be wrapped in a ToDict (with ‘running_’ prepended to the name) for simplicity.
Note
The decorator function does not need to be called: both @running_mean and @running_mean() are acceptable.
Example:
    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.running_mean(step_size=2) # Update every 2 steps
    ... @metrics.lambda_metric('my_metric')
    ... def metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {'running_my_metric': 4.0}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {'running_my_metric': 4.0}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8, triggers update
    {'running_my_metric': 6.0}
- Parameters:
clazz – The class to decorate
batch_size (int) – See RunningMean
step_size (int) – See RunningMean
dim (int, tuple) – See RunningMean
- Returns:
A decorator, or a MetricTree instance or wrapper
- torchbearer.metrics.decorators.std(clazz=None, unbiased=True, dim=None)
The std() decorator is used to add a Std to the MetricTree which will output a sample standard deviation value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Std will also be wrapped in a ToDict (with ‘_std’ appended) for simplicity.
Example:
    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.std
    ... @metrics.lambda_metric('my_metric')
    ... def metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
    {}
    >>> '%.4f' % metric.process_final()['my_metric_std']
    '2.0000'
- Parameters:
clazz – The class to decorate
unbiased (bool) – See Std
dim (int, tuple) – See Std
- Returns:
A MetricTree with a Std appended or a wrapper class that extends MetricTree
- torchbearer.metrics.decorators.to_dict(clazz)
The to_dict() decorator is used to wrap either a Metric class or a Metric instance with a ToDict instance. The result is that future output will be wrapped in a dict[name, value].
Example:
    >>> from torchbearer import metrics
    >>> @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> my_metric.process({'y_pred':4, 'y_true':5})
    9
    >>> @metrics.to_dict
    ... @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> my_metric.process({'y_pred':4, 'y_true':5})
    {'my_metric': 9}
- torchbearer.metrics.decorators.var(clazz=None, unbiased=True, dim=None)
The var() decorator is used to add a Var to the MetricTree which will output a sample variance value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Var will also be wrapped in a ToDict (with ‘_var’ appended) for simplicity.
Example:
    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.var
    ... @metrics.lambda_metric('my_metric')
    ... def metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
    {}
    >>> '%.4f' % metric.process_final()['my_metric_var']
    '4.0000'
- Parameters:
clazz – The class to decorate
unbiased (bool) – See Var
dim (int, tuple) – See Var
- Returns:
A MetricTree with a Var appended or a wrapper class that extends MetricTree
Metric Wrappers
Metric wrappers are classes which wrap instances of Metric or, in the case of EpochLambda and BatchLambda, functions. Typically, these should not be used directly (although this is entirely possible), but via the decorator API.
- class torchbearer.metrics.wrappers.BatchLambda(name, metric_function)
A metric which returns the output of the given function on each batch.
- Parameters:
name (str) – The name of the metric.
metric_function (func) – A metric function(‘y_pred’, ‘y_true’) to wrap.
- process(*args)
Return the output of the wrapped function.
- Parameters:
args – The torchbearer.Trial state.
- Returns:
The value of the metric function(‘y_pred’, ‘y_true’).
- class torchbearer.metrics.wrappers.EpochLambda(name, metric_function, running=True, step_size=50)
A metric wrapper which computes the given function for concatenated values of ‘y_true’ and ‘y_pred’ each epoch. It can be used as a running metric, which computes the function for batches of outputs with a given step size during training.
- Parameters:
name (str) – The name of the metric.
metric_function (func) – The function(‘y_pred’, ‘y_true’) to use as the metric.
running (bool) – True if this should act as a running metric.
step_size (int) – Step size to use between calls if running=True.
- process_final_train(*args)
Evaluate the function with the aggregated outputs.
- Returns:
The result of the function.
- process_final_validate(*args)
Evaluate the function with the aggregated outputs.
- Returns:
The result of the function.
- process_train(*args)
Concatenate the ‘y_true’ and ‘y_pred’ from the state along the 0 dimension; this must be the batch dimension. If this is a running metric, evaluate the function every step_size steps.
- Parameters:
args – The torchbearer.Trial state.
- Returns:
The current running result.
- process_validate(*args)
During validation, just concatenate ‘y_true’ and ‘y_pred’.
- Parameters:
args – The torchbearer.Trial state.
- reset(state)
Reset the ‘y_true’ and ‘y_pred’ caches.
- Parameters:
state (dict) – The torchbearer.Trial state.
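The caching behaviour of EpochLambda can be sketched with Python lists standing in for tensors, where list extension plays the role of concatenation along the batch dimension. ToyEpochLambda and accuracy are hypothetical. Evaluating on the first step and then every step_size steps is an assumption about when the running value refreshes.

```python
# Hypothetical sketch of EpochLambda-style behaviour using plain lists in
# place of tensors: predictions and targets are concatenated across batches,
# and the function is evaluated on everything seen so far.
class ToyEpochLambda:
    def __init__(self, name, metric_function, running=True, step_size=50):
        self.name = name
        self._fn = metric_function
        self._running = running
        self._step_size = step_size
        self.reset({})

    def reset(self, state):
        self._y_pred, self._y_true = [], []
        self._steps = 0
        self._result = None

    def process_train(self, state):
        # Concatenate along the batch dimension (list extension here).
        self._y_pred.extend(state['y_pred'])
        self._y_true.extend(state['y_true'])
        # Assumption: refresh on step 0 and then every step_size steps.
        if self._running and self._steps % self._step_size == 0:
            self._result = self._fn(self._y_pred, self._y_true)
        self._steps += 1
        return self._result

    def process_final_train(self, state=None):
        return self._fn(self._y_pred, self._y_true)


def accuracy(y_pred, y_true):
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)


m = ToyEpochLambda('acc', accuracy, step_size=2)
print(m.process_train({'y_pred': [1, 0], 'y_true': [1, 1]}))  # 0.5 (step 0: evaluated)
print(m.process_train({'y_pred': [1, 1], 'y_true': [1, 1]}))  # 0.5 (step 1: cached)
print(m.process_final_train())  # 0.75
```

Concatenating before evaluating matters for metrics such as ROC AUC that are not averages of per-batch values.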
- class torchbearer.metrics.wrappers.ToDict(metric)
The ToDict class is an AdvancedMetric which will put output from the inner Metric in a dict (mapping metric name to value) before returning. When in eval mode, ‘val_’ will be prepended to the metric name.
Example:
    >>> from torchbearer import metrics
    >>> @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric = metrics.ToDict(my_metric().build())
    >>> metric.process({'y_pred': 4, 'y_true': 5})
    {'my_metric': 9}
    >>> metric.eval()
    >>> metric.process({'y_pred': 4, 'y_true': 5})
    {'val_my_metric': 9}
- eval(data_key=None)
Put the metric in eval mode.
- Parameters:
data_key (StateKey) – The torchbearer data_key, if used
- process_final_train(*args)
Process the given state and return the final metric value for a training iteration.
- Returns:
The final metric value for a training iteration.
- process_final_validate(*args)
Process the given state and return the final metric value for a validation iteration.
- Returns:
The final metric value for a validation iteration.
- process_train(*args)
Process the given state and return the metric value for a training iteration.
- Returns:
The metric value for a training iteration.
- process_validate(*args)
Process the given state and return the metric value for a validation iteration.
- Returns:
The metric value for a validation iteration.
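A minimal sketch of the wrapping and the ‘val_’ prefix, using hypothetical stand-ins (ToyToDict, Add) instead of torchbearer classes. The train() method and returning an empty dict when the inner metric outputs None are assumptions.

```python
# Hypothetical sketch of ToDict-style wrapping: the inner metric's output is
# returned as {name: value}, with 'val_' prepended in eval mode.
class Add:
    """Toy inner metric (hypothetical)."""
    name = 'my_metric'

    def process(self, state):
        return state['y_pred'] + state['y_true']


class ToyToDict:
    def __init__(self, metric):
        self.metric = metric
        self.prefix = ''

    def train(self):
        self.prefix = ''

    def eval(self):
        self.prefix = 'val_'

    def process(self, state):
        value = self.metric.process(state)
        if value is None:  # assumption: no output means nothing to report
            return {}
        return {self.prefix + self.metric.name: value}


m = ToyToDict(Add())
print(m.process({'y_pred': 4, 'y_true': 5}))  # {'my_metric': 9}
m.eval()
print(m.process({'y_pred': 4, 'y_true': 5}))  # {'val_my_metric': 9}
```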
Metric Aggregators
Aggregators are a special kind of Metric which take as input the output from a previous metric or metrics. As a result, via a MetricTree, a series of aggregators can collect statistics such as Mean or Standard Deviation without needing to compute the underlying metric multiple times. This can, however, make the aggregators complex to use. It is therefore typically better to use the decorator API.
- class torchbearer.metrics.aggregators.Mean(name, dim=None)
Metric aggregator which calculates the mean of process outputs between calls to reset.
- Parameters:
name (str) – The name of this metric.
dim (int, tuple) – The dimension(s) on which to perform the mean. If left as None, the mean is taken over the whole Tensor
- process(*args)
Add the input to the rolling sum. Input must be a torch tensor.
- Parameters:
args – The output of some previous call to Metric.process().
- class torchbearer.metrics.aggregators.RunningMean(name, batch_size=50, step_size=10, dim=None)
A RunningMetric which outputs the running mean of its input tensors over the course of an epoch.
- Parameters:
name (str) – The name of this running mean.
batch_size (int) – The size of the deque of previous results to store.
step_size (int) – The number of iterations between aggregations.
dim (int, tuple) – The dimension(s) on which to perform the mean. If left as None, the mean is taken over the whole Tensor
- class torchbearer.metrics.aggregators.RunningMetric(name, batch_size=50, step_size=10)
A metric which aggregates batches of results and presents a method to periodically process these into a value.
Note
Running metrics only provide output during training.
- Parameters:
name (str) – The name of the metric.
batch_size (int) – The size of the deque of previous results to store.
step_size (int) – The number of iterations between aggregations.
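The batch_size and step_size mechanics can be sketched with collections.deque, whose maxlen discards the oldest entries once the deque is full. ToyRunningMean is hypothetical. With step_size=2 and inputs 4, 6 and 8 it produces the 4.0, 4.0, 6.0 sequence shown in the running_mean() example.

```python
from collections import deque


# Hypothetical sketch of a running mean over a bounded deque, refreshed every
# step_size iterations (not torchbearer's implementation).
class ToyRunningMean:
    def __init__(self, batch_size=50, step_size=10):
        self._cache = deque(maxlen=batch_size)  # oldest results drop off
        self._step_size = step_size
        self._step = 0
        self._value = None

    def process_train(self, value):
        self._cache.append(value)
        if self._step % self._step_size == 0:
            # Recompute only every step_size iterations; otherwise reuse the last value.
            self._value = sum(self._cache) / len(self._cache)
        self._step += 1
        return self._value


rm = ToyRunningMean(step_size=2)
print([rm.process_train(v) for v in (4, 6, 8)])  # [4.0, 4.0, 6.0]
```

A smaller batch_size makes the value track recent batches more closely; for example, batch_size=2 with step_size=1 yields 4.0, 5.0, 7.0 for the same inputs.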
- class torchbearer.metrics.aggregators.Std(name, unbiased=True, dim=None)
Metric aggregator which calculates the sample standard deviation of process outputs between calls to reset. Optionally calculate the population std if unbiased = False.
- Parameters:
name (str) – The name of this metric.
unbiased (bool) – If True (default), calculates the sample standard deviation, else, the population standard deviation
dim (int, tuple) – The dimension(s) on which to compute the std. If left as None, this will operate over the whole Tensor
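Accumulating a count, a sum and a sum of squares is enough to produce either standard deviation at the end of an epoch. ToyStd is a hypothetical, scalar-only sketch (no dim support); it shows how the unbiased flag changes the divisor from n to n - 1.

```python
import math


# Hypothetical sketch of a Std-style aggregator over scalar inputs; it keeps a
# running sum and sum of squares so the inputs themselves are never stored.
class ToyStd:
    def __init__(self, unbiased=True):
        self.unbiased = unbiased
        self.reset()

    def reset(self):
        self.n, self.total, self.total_sq = 0, 0.0, 0.0

    def process(self, x):
        self.n += 1
        self.total += x
        self.total_sq += x * x

    def process_final(self):
        mean = self.total / self.n
        ss = self.total_sq - self.n * mean * mean  # sum of squared deviations
        dof = self.n - 1 if self.unbiased else self.n
        return math.sqrt(max(ss, 0.0) / dof)  # clamp tiny negative rounding error


sample, population = ToyStd(), ToyStd(unbiased=False)
for x in (4, 6, 8):
    sample.process(x)
    population.process(x)
print(sample.process_final())      # 2.0
print(population.process_final())  # about 1.633, i.e. sqrt(8/3)
```

The single-pass sum-of-squares formula can lose precision when values are large relative to their spread; Welford's online algorithm is the usual remedy.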
- class torchbearer.metrics.aggregators.Var(name, unbiased=True, dim=None)
Metric aggregator which calculates the sample variance of process outputs between calls to reset. Optionally calculate the population variance if unbiased = False.
- Parameters:
name (str) – The name of this metric.
unbiased (bool) – If True (default), calculates the sample variance, else, the population variance
dim (int, tuple) – The dimension(s) on which to compute the variance. If left as None, this will operate over the whole Tensor
- process(*args)
Compute values required for the variance from the input. The input should be a torch Tensor. The sum and sum of squares will be computed over the provided dimension.
- Parameters:
args (torch.Tensor) – The output of some previous call to Metric.process().
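The sum and sum of squares suffice because of the following identity, which is standard algebra rather than something taken from the torchbearer source. For n inputs x_1, ..., x_n with mean x̄:

```latex
\begin{aligned}
S &= \sum_{i=1}^{n} x_i, \qquad Q = \sum_{i=1}^{n} x_i^2, \\
\sum_{i=1}^{n} (x_i - \bar{x})^2 &= Q - \frac{S^2}{n}, \\
\operatorname{Var}_{\mathrm{unbiased}} &= \frac{Q - S^2/n}{n - 1}, \qquad
\operatorname{Var}_{\mathrm{population}} = \frac{Q - S^2/n}{n}.
\end{aligned}
```

For the inputs 4, 6 and 8 used in the var() example, n = 3, S = 18 and Q = 116, so Q - S²/n = 116 - 108 = 8. The unbiased variance is 8/2 = 4, matching '4.0000', and the population variance is 8/3 ≈ 2.667.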
Base Metrics
Base metrics are the base classes which represent the metrics supplied with torchbearer. They all use the default_for_key() decorator so that they can be accessed in the call to torchbearer.Model via the following strings:
- ‘acc’ or ‘accuracy’: The DefaultAccuracy metric
- ‘binary_acc’ or ‘binary_accuracy’: The BinaryAccuracy metric
- ‘cat_acc’ or ‘cat_accuracy’: The CategoricalAccuracy metric
- ‘top_5_acc’ or ‘top_5_accuracy’: The TopKCategoricalAccuracy metric
- ‘top_10_acc’ or ‘top_10_accuracy’: The TopKCategoricalAccuracy metric with k=10
- ‘mse’: The MeanSquaredError metric
- ‘loss’: The Loss metric
- ‘epoch’: The Epoch metric
- ‘lr’: The LR metric
- ‘roc_auc’ or ‘roc_auc_score’: The RocAucScore metric
- class torchbearer.metrics.default.DefaultAccuracy
The default accuracy metric loads in a different accuracy metric depending on the loss function or criterion in use at the start of training. Default for keys: acc, accuracy. The following bindings are in place for both nn and functional variants:
  - cross entropy loss -> CategoricalAccuracy [DEFAULT]
  - nll loss -> CategoricalAccuracy
  - mse loss -> MeanSquaredError
  - bce loss -> BinaryAccuracy
  - bce loss with logits -> BinaryAccuracy
- process(*args)
Process the state and update the metric for one iteration.
- Parameters:
args – Arguments given to the metric. If this is a root level metric, it will be given state
- Returns:
None, or the value of the metric for this batch
- process_final(*args)
Process the terminal state and output the final value of the metric.
- Parameters:
args – Arguments given to the metric. If this is a root level metric, it will be given state
- Returns:
None, or the value of the metric for this epoch
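The binding table can be read as a lookup with a fallback. In this sketch the criteria are identified by the names of the corresponding torch.nn.functional functions, which is an assumption for illustration only. choose_accuracy and BINDINGS are hypothetical, and reading [DEFAULT] as the fallback for unrecognised criteria is also an assumption.

```python
# Hypothetical sketch of DefaultAccuracy's criterion-to-metric binding, with
# criteria identified by name strings (an illustrative assumption).
BINDINGS = {
    'cross_entropy': 'CategoricalAccuracy',
    'nll_loss': 'CategoricalAccuracy',
    'mse_loss': 'MeanSquaredError',
    'binary_cross_entropy': 'BinaryAccuracy',
    'binary_cross_entropy_with_logits': 'BinaryAccuracy',
}


def choose_accuracy(criterion_name):
    # Unrecognised criteria fall back to categorical accuracy (the [DEFAULT] binding).
    return BINDINGS.get(criterion_name, 'CategoricalAccuracy')


print(choose_accuracy('mse_loss'))    # MeanSquaredError
print(choose_accuracy('huber_loss'))  # CategoricalAccuracy
```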
- class torchbearer.metrics.primitives.BinaryAccuracy
Binary accuracy metric. Uses torch.eq to compare predictions to targets. Decorated with a mean and running_mean. Default for key: ‘binary_acc’.
- class torchbearer.metrics.primitives.CategoricalAccuracy(ignore_index=-100)
Categorical accuracy metric. Uses torch.max to determine predictions and compares to targets. Decorated with a mean, running_mean and std. Default for key: ‘cat_acc’.
- Parameters:
pred_key (StateKey) – The key in state which holds the predicted values
target_key (StateKey) – The key in state which holds the target values
ignore_index (int) – Specifies a target value that is ignored and does not contribute to the metric output. See https://pytorch.org/docs/stable/nn.html#crossentropyloss
- class torchbearer.metrics.primitives.TopKCategoricalAccuracy(k=5, ignore_index=-100)
Top K Categorical accuracy metric. Uses torch.topk to determine the top k predictions and compares to targets. Decorated with a mean, running_mean and std. Default for keys: ‘top_5_acc’, ‘top_10_acc’.
- Parameters:
pred_key (StateKey) – The key in state which holds the predicted values
target_key (StateKey) – The key in state which holds the target values
ignore_index (int) – Specifies a target value that is ignored and does not contribute to the metric output. See https://pytorch.org/docs/stable/nn.html#crossentropyloss
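Both accuracies reduce to checking whether the target is among the highest-scoring classes, with k = 1 giving the argmax (torch.max) case. This pure-Python topk_accuracy is a hypothetical sketch over nested lists. It assumes ignored targets are dropped from both the numerator and the denominator, in the spirit of the CrossEntropyLoss ignore_index linked above.

```python
# Hypothetical pure-Python sketch of categorical (k=1) and top-k accuracy with
# an ignore_index, mirroring the torch.max / torch.topk descriptions above.
def topk_accuracy(scores, targets, k=1, ignore_index=-100):
    hits, counted = 0, 0
    for row, target in zip(scores, targets):
        if target == ignore_index:
            continue  # ignored targets do not contribute at all
        # Indices of the k highest scores (the stable sort breaks ties by lower index).
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top
        counted += 1
    return hits / counted if counted else 0.0


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.9, 0.05, 0.05]]
targets = [1, 1, 0, -100]  # the last sample is ignored
print(topk_accuracy(scores, targets, k=1))  # 1 of 3 counted samples correct
print(topk_accuracy(scores, targets, k=2))  # 2 of 3 counted samples correct
```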
- class torchbearer.metrics.primitives.MeanSquaredError
Mean squared error metric. Computes the pixelwise squared error which is then averaged with decorators. Decorated with a mean and running_mean. Default for key: ‘mse’.
- class torchbearer.metrics.primitives.Loss
Simply returns the ‘loss’ value from the model state. Decorated with a mean, running_mean and std. Default for key: ‘loss’.
- State Requirements:
torchbearer.state.LOSS
: This key should map to the loss for the current batch
- class torchbearer.metrics.primitives.Epoch
Returns the ‘epoch’ from the model state. Default for key: ‘epoch’.
- State Requirements:
torchbearer.state.EPOCH
: This key should map to the number of the current epoch
- class torchbearer.metrics.roc_auc_score.RocAucScore(one_hot_labels=True, one_hot_offset=0, one_hot_classes=10)
Area Under ROC curve metric. Default for keys: ‘roc_auc’, ‘roc_auc_score’.
Note
Requires sklearn.metrics.
- Parameters:
one_hot_labels (bool) – If True, convert the labels to a one hot encoding. Required if they are not already one hot encoded.
one_hot_offset (int) – Subtracted from class labels; use if the labels are not already zero based.
one_hot_classes (int) – Number of classes for the one hot encoding.
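The label conversion controlled by these three parameters can be sketched with plain lists standing in for tensors. one_hot is a hypothetical helper, and raising on out-of-range labels is a choice of this sketch. The resulting rows would presumably be compared against per-class scores by sklearn.metrics.roc_auc_score.

```python
# Hypothetical sketch of the label preprocessing implied by one_hot_labels,
# one_hot_offset and one_hot_classes (plain lists instead of tensors).
def one_hot(labels, offset=0, classes=10):
    encoded = []
    for label in labels:
        index = label - offset  # shift labels so they start at zero
        if not 0 <= index < classes:
            raise ValueError(f'label {label} out of range after offset {offset}')
        row = [0] * classes
        row[index] = 1
        encoded.append(row)
    return encoded


# Labels numbered 1..3 need offset=1 to become zero based.
print(one_hot([1, 3, 2], offset=1, classes=3))  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```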
Timer
- class torchbearer.metrics.timer.TimerMetric(time_keys=())
- on_backward(state)
Perform some action with the given state as context after backward has been called on the loss.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_criterion(state)
Perform some action with the given state as context after the criterion has been evaluated.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_criterion_validation(state)
Perform some action with the given state as context after the criterion evaluation has been completed with the validation data.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_end(state)
Perform some action with the given state as context at the end of the model fitting.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_end_epoch(state)
Perform some action with the given state as context at the end of each epoch.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_end_training(state)
Perform some action with the given state as context after the training loop has completed.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_end_validation(state)
Perform some action with the given state as context at the end of the validation loop.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_forward(state)
Perform some action with the given state as context after the forward pass (model output) has been completed.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_forward_validation(state)
Perform some action with the given state as context after the forward pass (model output) has been completed with the validation data.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_sample(state)
Perform some action with the given state as context after data has been sampled from the generator.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_sample_validation(state)
Perform some action with the given state as context after data has been sampled from the validation generator.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_start(state)
Perform some action with the given state as context at the start of a model fit.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_start_epoch(state)
Perform some action with the given state as context at the start of each epoch.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_start_training(state)
Perform some action with the given state as context at the start of the training loop.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_start_validation(state)
Perform some action with the given state as context at the start of the validation loop.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_step_training(state)
Perform some action with the given state as context after step has been called on the optimiser.
- Parameters:
state (dict) – The current state dict of the Trial.
- on_step_validation(state)
Perform some action with the given state as context at the end of each validation step.
- Parameters:
state (dict) – The current state dict of the Trial.
- process(*args)
Process the state and update the metric for one iteration.
- Parameters:
args – Arguments given to the metric. If this is a root level metric, it will be given state
- Returns:
None, or the value of the metric for this batch