torchbearer.metrics
Base Classes
The base metric classes exist to enable complex data flow requirements between metrics. All metrics are either instances
of Metric or MetricFactory. These can then be collected in a MetricList or a
MetricTree. The MetricList simply aggregates calls from a list of metrics, whereas the
MetricTree will pass data from its root metric to each child and collect the outputs. This enables complex
running metrics and statistics without needing to compute the underlying values more than once. Typically,
constructions of this kind should be handled using the decorator API.

class torchbearer.metrics.metrics.AdvancedMetric(name)
The AdvancedMetric class is a metric which provides different process methods for training and validation. This enables running metrics which do not output intermediate steps during validation.
Parameters: name (str) – The name of the metric.

process(*args)
Depending on the current mode, return the result of either ‘process_train’ or ‘process_validate’.
Parameters: state (dict) – The current state dict of the Model.
Returns: The metric value.

process_final(*args)
Depending on the current mode, return the result of either ‘process_final_train’ or ‘process_final_validate’.
Parameters: state (dict) – The current state dict of the Model.
Returns: The final metric value.

process_final_train(*args)
Process the given state and return the final metric value for a training iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The final metric value for a training iteration.

process_final_validate(*args)
Process the given state and return the final metric value for a validation iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The final metric value for a validation iteration.

process_train(*args)
Process the given state and return the metric value for a training iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The metric value for a training iteration.
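The mode-based dispatch performed by AdvancedMetric can be sketched in plain Python as follows. This is a hypothetical, simplified stand-in rather than torchbearer's source. The class names SketchAdvancedMetric and BatchCounter, the train()/eval() toggles and the "default hooks return None" behaviour are all assumptions made for illustration. The subclass shows how a metric can report intermediate values during training while staying silent during validation.

```python
class SketchAdvancedMetric:
    """Hypothetical sketch (not torchbearer code): route calls by train/eval mode."""

    def __init__(self, name):
        self.name = name
        self._training = True  # assumption: start in training mode

    def train(self):
        self._training = True

    def eval(self):
        self._training = False

    def process(self, *args):
        # Per-batch call: pick the training or validation hook
        if self._training:
            return self.process_train(*args)
        return self.process_validate(*args)

    def process_final(self, *args):
        # End-of-epoch call: same routing for the final hooks
        if self._training:
            return self.process_final_train(*args)
        return self.process_final_validate(*args)

    # Assumed defaults: hooks do nothing unless a subclass overrides them
    def process_train(self, *args):
        return None

    def process_validate(self, *args):
        return None

    def process_final_train(self, *args):
        return None

    def process_final_validate(self, *args):
        return None


class BatchCounter(SketchAdvancedMetric):
    """Hypothetical example: running count in training, final count only in validation."""

    def __init__(self):
        super().__init__('batches')
        self.count = 0

    def process_train(self, state):
        self.count += 1
        return self.count  # intermediate output during training

    def process_validate(self, state):
        self.count += 1  # no intermediate output during validation

    def process_final_validate(self, state):
        return self.count
```

Subclasses only override the hooks they need. The dispatch logic lives once in the base class.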

class torchbearer.metrics.metrics.Metric(name)
Base metric class. process is called on each batch and process_final at the end of each epoch. The metric contract allows metrics to take any positional args but not kwargs. The initial metric call is given the state; however, subsequent metrics can pass any values desired.
Note
All metrics must extend this class.
Parameters: name (str) – The name of the metric

process(*args)
Process the state and update the metric for one iteration.
Parameters: args – Arguments given to the metric. If this is a root-level metric, it will be given the state.
Returns: None, or the value of the metric for this batch

process_final(*args)
Process the terminal state and output the final value of the metric.
Parameters: args – Arguments given to the metric. If this is a root-level metric, it will be given the state.
Returns: None, or the value of the metric for this epoch
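To make the Metric contract concrete, the sketch below defines a hypothetical root-level metric (MeanAbsError) on top of a minimal stand-in base class (SketchMetric), then drives it with a hypothetical run_epoch loop. The loop calls process once per batch and process_final once at the end. None of these names are torchbearer APIs. The reset(state) hook mirrors the reset({}) calls seen in the examples on this page, and plain lists stand in for tensors.

```python
class SketchMetric:
    """Hypothetical minimal base mirroring the Metric contract (positional args only)."""

    def __init__(self, name):
        self.name = name

    def reset(self, state):
        pass

    def process(self, *args):
        return None

    def process_final(self, *args):
        return None


class MeanAbsError(SketchMetric):
    """Root-level metric: receives the state dict as its single positional argument."""

    def reset(self, state):
        self.total = 0.0
        self.n = 0

    def process(self, state):
        errs = [abs(p - t) for p, t in zip(state['y_pred'], state['y_true'])]
        self.total += sum(errs)
        self.n += len(errs)
        return sum(errs) / len(errs)  # value for this batch

    def process_final(self, state):
        return self.total / self.n  # value for the whole epoch


def run_epoch(metric, batches):
    # Hypothetical driver: process on every batch, process_final once per epoch
    metric.reset({})
    per_batch = [metric.process(state) for state in batches]
    return per_batch, metric.process_final({})
```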

class torchbearer.metrics.metrics.MetricFactory
A simple implementation of a factory pattern. Used to enable construction of complex metrics using decorators.

class torchbearer.metrics.metrics.MetricList(metric_list)
The MetricList class is a wrapper for a list of metrics which acts as a single metric and produces a dictionary of outputs.
Parameters: metric_list (list) – The list of metrics to be wrapped. If the list contains a MetricList, this will be unwrapped. Any strings in the list will be retrieved from metrics.DEFAULT_METRICS.

process(state)
Process each metric and wrap the results in a dictionary which maps metric names to values.
Parameters: state (dict) – The current state dict of the Model.
Returns: dict[str,any] – A dictionary which maps metric names to values.

process_final(state)
Process each metric and wrap the results in a dictionary which maps metric names to values.
Parameters: state (dict) – The current state dict of the Model.
Returns: dict[str,any] – A dictionary which maps metric names to values.
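The behaviour described for MetricList can be illustrated with a small pure-Python sketch. It looks up strings, unwraps nested lists and merges each metric's dict output into one dictionary. SketchMetricList, ConstMetric and REGISTRY are hypothetical names. REGISTRY stands in for metrics.DEFAULT_METRICS and is assumed here to map names to zero-argument constructors. This is not torchbearer's implementation.

```python
class ConstMetric:
    """Hypothetical metric returning a fixed dict, standing in for a ToDict-wrapped metric."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def process(self, state):
        return {self.name: self.value}

    def process_final(self, state):
        return {self.name: self.value}


# Hypothetical stand-in for metrics.DEFAULT_METRICS: name -> constructor
REGISTRY = {'one': lambda: ConstMetric('one', 1)}


class SketchMetricList:
    def __init__(self, metric_list):
        self.metric_list = []
        for m in metric_list:
            if isinstance(m, str):
                m = REGISTRY[m]()  # strings are looked up by name
            if isinstance(m, SketchMetricList):
                self.metric_list.extend(m.metric_list)  # unwrap nested lists
            else:
                self.metric_list.append(m)

    def _call(self, fn_name, state):
        # Merge each metric's dict output into a single dictionary
        result = {}
        for m in self.metric_list:
            out = getattr(m, fn_name)(state)
            if out is not None:
                result.update(out)
        return result

    def process(self, state):
        return self._call('process', state)

    def process_final(self, state):
        return self._call('process_final', state)
```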

class torchbearer.metrics.metrics.MetricTree(metric)
A tree structure which has a node Metric and some children. Upon execution, the node is called with the input and its output is passed to each of the children. A dict is updated with the results.
Parameters: metric (Metric) – The metric to act as the root node of the tree / subtree

add_child(child)
Add a child to this node of the tree
Parameters: child (Metric) – The child to add
Returns: None

process(*args)
Process this node and then pass the output to each child.
Returns: A dict containing all results from the children

process_final(*args)
Process this node and then pass the output to each child.
Returns: A dict containing all results from the children
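The data flow through a MetricTree can be sketched as follows. The root is evaluated once, its output is handed to every child, and the children's dict outputs are merged. SketchMetricTree, CountingSquare and Named are hypothetical classes for illustration, not torchbearer code. Because the sketch tree exposes the same process methods as its children, a subtree can itself be added as a child.

```python
class CountingSquare:
    """Hypothetical root metric: squares its input and counts how often it runs."""

    def __init__(self):
        self.calls = 0

    def process(self, x):
        self.calls += 1
        return x * x

    def process_final(self, x):
        self.calls += 1
        return x * x


class Named:
    """Hypothetical child: applies fn to the root's output and wraps it under key."""

    def __init__(self, key, fn):
        self.key = key
        self.fn = fn

    def process(self, value):
        return {self.key: self.fn(value)}

    def process_final(self, value):
        return {self.key: self.fn(value)}


class SketchMetricTree:
    def __init__(self, metric):
        self.root = metric
        self.children = []

    def add_child(self, child):
        self.children.append(child)

    def _run(self, fn_name, *args):
        value = getattr(self.root, fn_name)(*args)  # root computed once
        result = {}
        for child in self.children:
            out = getattr(child, fn_name)(value)  # every child sees the same value
            if out is not None:
                result.update(out)
        return result

    def process(self, *args):
        return self._run('process', *args)

    def process_final(self, *args):
        return self._run('process_final', *args)
```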
Decorators - The Decorator API
The decorator API is the core way to interact with metrics in torchbearer. All of the classes and functionality handled
here can be reproduced by manually interacting with the classes if necessary. Broadly speaking, the decorator API is
used to construct a MetricFactory which will build a MetricTree that handles data flow between
instances of Mean, RunningMean, Std etc.

torchbearer.metrics.decorators.default_for_key(key)
The default_for_key() decorator will register the given metric in the global metric dict (metrics.DEFAULT_METRICS) so that it can be referenced by name in instances of MetricList, such as in the list given to the torchbearer.Model.
Example:

    @default_for_key('acc')
    class CategoricalAccuracy(metrics.BatchLambda):
        ...

Parameters: key (str) – The key to use when referencing the metric
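A registry decorator of this kind can be sketched in a few lines of plain Python. The dict DEFAULT_METRICS below is a hypothetical stand-in for metrics.DEFAULT_METRICS. The class Accuracy and the helper resolve() are illustrative assumptions, not torchbearer code. Stacking the decorator is one way a single class could answer to two keys, such as ‘acc’ and ‘accuracy’.

```python
DEFAULT_METRICS = {}  # hypothetical stand-in for metrics.DEFAULT_METRICS


def default_for_key(key):
    """Sketch: register the decorated class under key and return it unchanged."""
    def decorator(clazz):
        DEFAULT_METRICS[key] = clazz
        return clazz
    return decorator


@default_for_key('accuracy')
@default_for_key('acc')
class Accuracy:
    name = 'acc'


def resolve(metric):
    # Hypothetical helper: strings are looked up and instantiated, other values pass through
    if isinstance(metric, str):
        return DEFAULT_METRICS[metric]()
    return metric
```

Because each decorator returns the class unchanged, the decorated name still refers to the original class.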

torchbearer.metrics.decorators.lambda_metric(name, on_epoch=False)
The lambda_metric() decorator is used to convert a function of y_pred and y_true into a Metric instance. In fact, it returns a MetricFactory which is used to build a Metric. This can make things complicated, as in the following example:

    @metrics.lambda_metric('my_metric')
    def my_metric(y_pred, y_true):
        ...  # Calculate some metric

    model = Model(metrics=[my_metric()])  # Note we have to call `my_metric` in order to instantiate the class

Parameters:
- name – The name of the metric (e.g. ‘loss’)
- on_epoch – If True the metric will be an instance of EpochLambda instead of BatchLambda
Returns: A decorator which replaces a function with a MetricFactory
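The two-step "call, then build" pattern above can be illustrated with a hypothetical pure-Python sketch. The decorator replaces the function with a factory class. Calling my_metric() creates a factory, and build() creates the actual metric. BatchLambdaSketch, FactorySketch and this lambda_metric are stand-ins invented for illustration, and the on_epoch option is omitted. The sketch reproduces the my_metric().build().process({'y_pred': 4, 'y_true': 5}) == 9 behaviour shown in the to_dict() example on this page.

```python
class BatchLambdaSketch:
    """Hypothetical stand-in for BatchLambda: applies fn to y_pred and y_true."""

    def __init__(self, name, fn):
        self.name = name
        self.fn = fn

    def process(self, state):
        return self.fn(state['y_pred'], state['y_true'])


class FactorySketch:
    """Hypothetical stand-in for MetricFactory: defers construction until build()."""

    def __init__(self, name, fn):
        self.name = name
        self.fn = fn

    def build(self):
        return BatchLambdaSketch(self.name, self.fn)


def lambda_metric(name):
    # The decorator returns a factory *class*, so users call my_metric() to get
    # a factory instance, then .build() to get the metric itself.
    def decorator(fn):
        class _Factory(FactorySketch):
            def __init__(self):
                super().__init__(name, fn)
        return _Factory
    return decorator


@lambda_metric('my_metric')
def my_metric(y_pred, y_true):
    return y_pred + y_true
```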

torchbearer.metrics.decorators.mean(clazz)
The mean() decorator is used to add a Mean to the MetricTree which will output a mean value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Mean will also be wrapped in a ToDict for simplicity.
Example:

    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.mean
    ... @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric = my_metric().build()
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
    {}
    >>> metric.process_final()
    {'my_metric': 6.0}

Parameters: clazz – The class to decorate
Returns: A MetricFactory which can be instantiated and built to append a Mean to the MetricTree

torchbearer.metrics.decorators.running_mean(clazz=None, batch_size=50, step_size=10)
The running_mean() decorator is used to add a RunningMean to the MetricTree. As with the other decorators, a MetricFactory is created which will do this upon the call to MetricFactory.build(). If the inner class is not, or does not build, a MetricTree, then one will be created. The RunningMean will be wrapped in a ToDict (with ‘running_’ prepended to the name) for simplicity.
Note
The decorator function does not need to be called if not desired; both @running_mean and @running_mean() are acceptable.
Example:

    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.running_mean(step_size=2) # Update every 2 steps
    ... @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric = my_metric().build()
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {'running_my_metric': 4.0}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {'running_my_metric': 4.0}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8, triggers update
    {'running_my_metric': 6.0}

Parameters:
- clazz – The class to decorate
- batch_size – See RunningMean
- step_size – See RunningMean
Returns: decorator or MetricFactory

torchbearer.metrics.decorators.std(clazz)
The std() decorator is used to add a Std to the MetricTree which will output a population standard deviation value at the end of each epoch. At build time, if the inner class is not a MetricTree, one will be created. The Std will also be wrapped in a ToDict (with ‘_std’ appended) for simplicity.
Example:

    >>> import torch
    >>> from torchbearer import metrics
    >>> @metrics.std
    ... @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric = my_metric().build()
    >>> metric.reset({})
    >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
    {}
    >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
    {}
    >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
    {}
    >>> '%.4f' % metric.process_final()['my_metric_std']
    '1.6330'

Parameters: clazz – The class to decorate
Returns: A MetricFactory which can be instantiated and built to append a Std to the MetricTree

torchbearer.metrics.decorators.to_dict(clazz)
The to_dict() decorator is used to wrap either a Metric or MetricFactory instance with a ToDict instance. The result is that future output will be wrapped in a dict[name, value].
Example:

    >>> from torchbearer import metrics
    >>> @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> my_metric().build().process({'y_pred':4, 'y_true':5})
    9
    >>> @metrics.to_dict
    ... @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> my_metric().build().process({'y_pred':4, 'y_true':5})
    {'my_metric': 9}

Parameters: clazz – The class to decorate
Returns: A MetricFactory which can be instantiated and built to wrap the given class in a ToDict

Metric Wrappers
Metric wrappers are classes which wrap instances of Metric or, in the case of EpochLambda and
BatchLambda, functions. Typically, these are used via the decorator API rather than directly (although direct use is
entirely possible).

class torchbearer.metrics.wrappers.BatchLambda(name, metric_function)
A metric which returns the output of the given function on each batch.
Parameters:
- name (str) – The name of the metric.
- metric_function – A metric function(‘y_pred’, ‘y_true’) to wrap.

process(state)
Return the output of the wrapped function.
Parameters: state (dict) – The torchbearer.Model state.
Returns: The value of the metric function(‘y_pred’, ‘y_true’).

class torchbearer.metrics.wrappers.EpochLambda(name, metric_function, running=True, step_size=50)
A metric wrapper which computes the given function for concatenated values of ‘y_true’ and ‘y_pred’ each epoch. It can also be used as a running metric which computes the function for batches of outputs with a given step size during training.
Parameters:
- name (str) – The name of the metric.
- metric_function – The function(‘y_pred’, ‘y_true’) to use as the metric.
- running (bool) – True if this should act as a running metric.
- step_size (int) – Step size to use between calls if running=True.

process_final_train(state)
Evaluate the function with the aggregated outputs.
Parameters: state (dict) – The torchbearer.Model state.
Returns: The result of the function.

process_final_validate(state)
Evaluate the function with the aggregated outputs.
Parameters: state (dict) – The torchbearer.Model state.
Returns: The result of the function.

process_train(state)
Concatenate the ‘y_true’ and ‘y_pred’ from the state along the 0 dimension. If this is a running metric, evaluate the function every step_size steps.
Parameters: state (dict) – The torchbearer.Model state.
Returns: The current running result.

process_validate(state)
During validation, just concatenate ‘y_true’ and ‘y_pred’.
Parameters: state (dict) – The torchbearer.Model state.

reset(state)
Reset the ‘y_true’ and ‘y_pred’ caches.
Parameters: state (dict) – The torchbearer.Model state.
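The EpochLambda idea can be sketched with Python lists in place of tensors. Outputs are accumulated across batches, and the function is applied to the full concatenation. This matters for metrics that do not decompose over batches (ROC AUC is a common example). EpochLambdaSketch and accuracy are hypothetical names. list.extend plays the role of concatenating along dimension 0. The running schedule (evaluate on the first step and then every step_size steps) is an assumption of this sketch, not a statement about torchbearer's exact behaviour.

```python
def accuracy(y_pred, y_true):
    # Fraction of positions where the prediction equals the target
    return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)


class EpochLambdaSketch:
    """Hypothetical stand-in for EpochLambda using Python lists instead of tensors."""

    def __init__(self, name, metric_function, running=True, step_size=50):
        self.name = name
        self.metric_function = metric_function
        self.running = running
        self.step_size = step_size
        self.reset({})

    def reset(self, state):
        # Clear the y_pred / y_true caches at the start of each epoch
        self.y_pred, self.y_true = [], []
        self.steps = 0
        self.result = None

    def process_train(self, state):
        # list.extend stands in for concatenation along dimension 0
        self.y_pred.extend(state['y_pred'])
        self.y_true.extend(state['y_true'])
        if self.running and self.steps % self.step_size == 0:
            self.result = self.metric_function(self.y_pred, self.y_true)
        self.steps += 1
        return self.result  # last running result (may be stale between updates)

    def process_validate(self, state):
        # During validation, only accumulate
        self.y_pred.extend(state['y_pred'])
        self.y_true.extend(state['y_true'])

    def process_final_train(self, state):
        return self.metric_function(self.y_pred, self.y_true)

    def process_final_validate(self, state):
        return self.metric_function(self.y_pred, self.y_true)
```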

class torchbearer.metrics.wrappers.ToDict(metric)
The ToDict class is an AdvancedMetric which will put output from the inner Metric in a dict (mapping metric name to value) before returning. When in eval mode, ‘val_’ will be prepended to the metric name.
Example:

    >>> from torchbearer import metrics
    >>> @metrics.lambda_metric('my_metric')
    ... def my_metric(y_pred, y_true):
    ...     return y_pred + y_true
    ...
    >>> metric = metrics.ToDict(my_metric().build())
    >>> metric.process({'y_pred': 4, 'y_true': 5})
    {'my_metric': 9}
    >>> metric.eval()
    >>> metric.process({'y_pred': 4, 'y_true': 5})
    {'val_my_metric': 9}

Parameters: metric (metrics.Metric) – The Metric instance to wrap.

process_final_train(*args)
Process the given state and return the final metric value for a training iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The final metric value for a training iteration.

process_final_validate(*args)
Process the given state and return the final metric value for a validation iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The final metric value for a validation iteration.

process_train(*args)
Process the given state and return the metric value for a training iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The metric value for a training iteration.

process_validate(*args)
Process the given state and return the metric value for a validation iteration.
Parameters: state (dict) – The current state dict of the Model.
Returns: The metric value for a validation iteration.
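The naming rule used by ToDict can be shown with a small hypothetical sketch. The inner output is placed under the metric's name, and ‘val_’ is prepended in eval mode. ToDictSketch and AddMetric are illustrative stand-ins, not torchbearer classes. Returning None when the inner metric produces nothing is an assumption of this sketch. The first two checks mirror the ToDict example above.

```python
class ToDictSketch:
    """Hypothetical stand-in for ToDict: wraps an inner metric's output in {name: value}."""

    def __init__(self, metric):
        self.metric = metric
        self._training = True

    def train(self):
        self._training = True

    def eval(self):
        self._training = False

    def _key(self):
        # In eval (validation) mode the key gets a 'val_' prefix
        return self.metric.name if self._training else 'val_' + self.metric.name

    def process(self, *args):
        value = self.metric.process(*args)
        return None if value is None else {self._key(): value}

    def process_final(self, *args):
        value = self.metric.process_final(*args)
        return None if value is None else {self._key(): value}


class AddMetric:
    """Hypothetical inner metric: y_pred + y_true per batch, nothing at epoch end."""
    name = 'my_metric'

    def process(self, state):
        return state['y_pred'] + state['y_true']

    def process_final(self, state):
        return None
```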

Metric Aggregators
Aggregators are a special kind of Metric which takes as input the output from a previous metric or metrics.
As a result, via a MetricTree, a series of aggregators can collect statistics such as Mean or Standard
Deviation without needing to compute the underlying metric multiple times. This can, however, make the aggregators
complex to use. It is therefore typically better to use the decorator API.

class torchbearer.metrics.aggregators.Mean(name)
Metric aggregator which calculates the mean of process outputs between calls to reset.
Parameters: name (str) – The name of this metric.

process(data)
Add the input to the rolling sum.
Parameters: data (torch.Tensor) – The output of some previous call to Metric.process().

process_final(data)
Compute and return the mean of all metric values since the last call to reset.
Parameters: data (torch.Tensor) – The output of some previous call to Metric.process_final().
Returns: The mean of the metric values since the last call to reset.
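A Mean-style aggregator only needs a running sum and a count between resets. The hypothetical MeanSketch below treats each input as a single scalar observation, where torchbearer passes tensors. With inputs 4, 6 and 8 it gives the 6.0 shown in the mean() decorator example. It is an illustration, not torchbearer's implementation.

```python
class MeanSketch:
    """Hypothetical stand-in for Mean: sums inputs between resets, divides at the end."""

    def __init__(self, name):
        self.name = name
        self.reset({})

    def reset(self, state):
        self._sum = 0.0
        self._count = 0

    def process(self, data):
        # Add the input to the rolling sum; one observation per call
        self._sum += data
        self._count += 1

    def process_final(self, data=None):
        # Mean of everything seen since the last reset
        return self._sum / self._count
```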

class torchbearer.metrics.aggregators.RunningMean(name, batch_size=50, step_size=10)
A RunningMetric which outputs the mean of a sequence of its input over the course of an epoch.
Parameters:
- name (str) – The name of this running mean.
- batch_size (int) – The maximum number of previous results to store in the deque.
- step_size (int) – The number of iterations between aggregations.

class torchbearer.metrics.aggregators.RunningMetric(name, batch_size=50, step_size=10)
A metric which aggregates batches of results and presents a method to periodically process these into a value.
Note
Running metrics only provide output during training.
Parameters:
- name (str) – The name of the metric.
- batch_size (int) – The maximum number of previous results to store in the deque.
- step_size (int) – The number of iterations between aggregations.
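The batch_size and step_size parameters can be illustrated with a bounded deque. At most batch_size recent results are kept, and the aggregate is recomputed only every step_size calls. RunningMeanSketch is hypothetical. The update schedule (steps 0, step_size, 2 * step_size, ...) is an assumption, chosen because it reproduces the 4.0, 4.0, 6.0 outputs of the running_mean(step_size=2) example above.

```python
from collections import deque


class RunningMeanSketch:
    """Hypothetical stand-in for RunningMean: mean over the last batch_size results,
    recomputed every step_size calls during training."""

    def __init__(self, name, batch_size=50, step_size=10):
        self.name = name
        self.step_size = step_size
        self._cache = deque(maxlen=batch_size)  # oldest results drop off automatically
        self.reset({})

    def reset(self, state):
        self._cache.clear()
        self._i = 0
        self._result = None

    def process_train(self, value):
        self._cache.append(value)
        if self._i % self.step_size == 0:  # assumed schedule: 0, step_size, 2*step_size, ...
            self._result = sum(self._cache) / len(self._cache)
        self._i += 1
        return self._result  # between updates, the previous value is repeated

    def process_validate(self, value):
        return None  # running metrics produce no output during validation
```

Using deque(maxlen=...) keeps the window bounded without manual popping.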

class torchbearer.metrics.aggregators.Std(name)
Metric aggregator which calculates the standard deviation of process outputs between calls to reset.
Parameters: name (str) – The name of this metric.

process(data)
Compute values required for the std from the input.
Parameters: data (torch.Tensor) – The output of some previous call to Metric.process().

process_final(data)
Compute and return the final standard deviation.
Parameters: data (torch.Tensor) – The output of some previous call to Metric.process_final().
Returns: The standard deviation of the observations since the last call to reset.
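One way to compute a population standard deviation from streaming values uses a running sum and sum of squares, so that Var(x) = E[x²] − (E[x])². The hypothetical StdSketch below does this with scalars. It is not torchbearer's code. For inputs 4, 6 and 8 it yields 1.6330 to four places, matching the std() decorator example.

```python
import math


class StdSketch:
    """Hypothetical stand-in for Std: population standard deviation via running sums."""

    def __init__(self, name):
        self.name = name
        self.reset({})

    def reset(self, state):
        self._sum = 0.0
        self._sum_sq = 0.0
        self._count = 0

    def process(self, data):
        # Accumulate the quantities needed for the variance
        self._sum += data
        self._sum_sq += data * data
        self._count += 1

    def process_final(self, data=None):
        mean = self._sum / self._count
        # Population variance E[x^2] - (E[x])^2; clamp tiny negatives caused by rounding
        variance = max(self._sum_sq / self._count - mean * mean, 0.0)
        return math.sqrt(variance)
```

The sum-of-squares form is cheap, but it can lose precision when the mean is large relative to the spread. Welford's online algorithm is the usual, more numerically stable alternative.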

Base Metrics
Base metrics are the base classes which represent the metrics supplied with torchbearer. They all use the
default_for_key() decorator so that they can be accessed in the call to torchbearer.Model via the
following strings:
- ‘acc’ or ‘accuracy’: The CategoricalAccuracy metric
- ‘loss’: The Loss metric
- ‘epoch’: The Epoch metric
- ‘roc_auc’ or ‘roc_auc_score’: The RocAucScore metric

class torchbearer.metrics.primitives.CategoricalAccuracy
Categorical accuracy metric. Uses torch.max to determine predictions and compares to targets.
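The computation behind categorical accuracy can be sketched without torch. Take the index of the highest score in each row as the prediction, compare it with the integer target, and average the matches. The function categorical_accuracy is a hypothetical pure-Python illustration. With tensors, the same idea is commonly written using torch.max(y_pred, 1), which returns both values and indices, but this sketch makes no claim about torchbearer's exact code.

```python
def categorical_accuracy(y_pred, y_true):
    """Hypothetical sketch: argmax over each row of scores, compared with integer targets."""
    correct = 0
    for scores, target in zip(y_pred, y_true):
        # Index of the largest score, analogous to the indices returned by torch.max(..., 1)
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == target)
    return correct / len(y_true)
```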

class torchbearer.metrics.primitives.Epoch
Returns the ‘epoch’ from the model state.