Source code for torchbearer.metrics.decorators

"""
The decorator API is the core way to interact with metrics in torchbearer. All of the classes and functionality handled
here can be reproduced by manually interacting with the classes if necessary. Broadly speaking, the decorator API is
used to construct a :class:`.MetricFactory` which will build a :class:`.MetricTree` that handles data flow between
instances of :class:`.Mean`, :class:`.RunningMean`, :class:`.Std` etc.
"""

import inspect

from torchbearer import metrics

from torchbearer.metrics import MetricFactory, EpochLambda, BatchLambda, ToDict, Mean, MetricTree, Std, RunningMean


def default_for_key(key):
    """Build a decorator that registers a metric under ``key`` in the global metric dict.

    The decorated object is stored in `metrics.DEFAULT_METRICS` so that it can be referenced by name in instances of
    :class:`.MetricList`, such as the list given to the :class:`.torchbearer.Model`. If the decorated object is a
    class, it is instantiated (with no arguments) and the instance is registered; any other object is registered
    as given. The decorated object itself is returned unchanged.

    Example: ::

        @default_for_key('acc')
        class CategoricalAccuracy(metrics.BatchLambda):
            ...

    :param key: The key to use when referencing the metric
    :type key: str
    :return: A decorator which registers its argument and hands it back untouched
    """
    def register(target):
        # Classes are registered as instances; everything else is stored as-is.
        registered = target() if inspect.isclass(target) else target
        metrics.DEFAULT_METRICS[key] = registered
        return target

    return register
def lambda_metric(name, on_epoch=False):
    """Turn a function of `y_pred, y_true` into a :class:`.MetricFactory` for a lambda metric.

    The decorated function is not replaced by a :class:`.Metric` directly. It is replaced by a
    :class:`.MetricFactory` subclass whose :meth:`build` creates the metric. This means the decorated name has to be
    called to obtain a factory instance, as in the following example: ::

        @metrics.lambda_metric('my_metric')
        def my_metric(y_pred, y_true):
            ... # Calculate some metric

        model = Model(metrics=[my_metric()]) # Note we have to call `my_metric` in order to instantiate the class

    :param name: The name of the metric (e.g. 'loss')
    :param on_epoch: If True the metric will be an instance of :class:`.EpochLambda` instead of :class:`.BatchLambda`
    :return: A decorator which replaces a function with a :class:`.MetricFactory`
    """
    def decorator(metric_function):
        class LambdaFactory(MetricFactory):
            def build(self):
                # Select the lambda metric type from the flag captured by the closure.
                lambda_type = EpochLambda if on_epoch else BatchLambda
                return lambda_type(name, metric_function)

        return LambdaFactory

    return decorator
def to_dict(clazz):
    """Wrap the output of a :class:`.Metric` or :class:`.MetricFactory` in a :class:`.ToDict`.

    The returned factory instantiates ``clazz`` with whatever arguments it is given. When built, it builds the inner
    object first if that object is a :class:`.MetricFactory`, then wraps the result in a :class:`.ToDict` so that
    future output is a `dict[name, value]`.

    Example: ::

        >>> from torchbearer import metrics

        >>> @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> my_metric().build().process({'y_pred':4, 'y_true':5})
        9
        >>> @metrics.to_dict
        ... @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> my_metric().build().process({'y_pred':4, 'y_true':5})
        {'my_metric': 9}

    :param clazz: The class to *decorate*
    :return: A :class:`.MetricFactory` which can be instantiated and built to wrap the given class in a
        :class:`.ToDict`
    """
    class DictFactory(MetricFactory):
        def __init__(self, *args, **kwargs):
            # Forward all constructor arguments to the decorated class.
            self.inner = clazz(*args, **kwargs)

        def build(self):
            wrapped = self.inner
            # A nested factory has to be built before it can be wrapped.
            if isinstance(wrapped, MetricFactory):
                wrapped = wrapped.build()
            return ToDict(wrapped)

    return DictFactory
def mean(clazz):
    """Add a :class:`.Mean` to the :class:`.MetricTree`, which will output a mean value at the end of each epoch.

    At build time, the inner object is built first if it is a :class:`.MetricFactory`. If the result is not a
    :class:`.MetricTree`, one is created around it. A :class:`.Mean` named after the tree and wrapped in a
    :class:`.ToDict` is then added as a child.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.mean
        ... @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric = my_metric().build()
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
        {}
        >>> metric.process_final()
        {'my_metric': 6.0}

    :param clazz: The class to *decorate*
    :return: A :class:`.MetricFactory` which can be instantiated and built to append a :class:`.Mean` to the
        :class:`.MetricTree`
    """
    class MeanFactory(MetricFactory):
        def __init__(self, *args, **kwargs):
            # Forward all constructor arguments to the decorated class.
            self.inner = clazz(*args, **kwargs)

        def build(self):
            tree = self.inner
            # Resolve a nested factory into the object it produces.
            if isinstance(tree, MetricFactory):
                tree = tree.build()
            # Children can only be attached to a tree, so promote anything else to a tree root.
            if not isinstance(tree, MetricTree):
                tree = MetricTree(tree)
            tree.add_child(ToDict(Mean(tree.name)))
            return tree

    return MeanFactory
def std(clazz):
    """Add a :class:`.Std` to the :class:`.MetricTree`, which will output a population standard deviation value at
    the end of each epoch.

    At build time, the inner object is built first if it is a :class:`.MetricFactory`. If the result is not a
    :class:`.MetricTree`, one is created around it. A :class:`.Std` (named after the tree with '_std' appended) is
    wrapped in a :class:`.ToDict` and added as a child.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.std
        ... @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric = my_metric().build()
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
        {}
        >>> '%.4f' % metric.process_final()['my_metric_std']
        '1.6330'

    :param clazz: The class to *decorate*
    :return: A :class:`.MetricFactory` which can be instantiated and built to append a :class:`.Std` to the
        :class:`.MetricTree`
    """
    class StdFactory(MetricFactory):
        def __init__(self, *args, **kwargs):
            # Forward all constructor arguments to the decorated class.
            self.inner = clazz(*args, **kwargs)

        def build(self):
            tree = self.inner
            # Resolve a nested factory into the object it produces.
            if isinstance(tree, MetricFactory):
                tree = tree.build()
            # Children can only be attached to a tree, so promote anything else to a tree root.
            if not isinstance(tree, MetricTree):
                tree = MetricTree(tree)
            tree.add_child(ToDict(Std(tree.name + '_std')))
            return tree

    return StdFactory
def running_mean(clazz=None, batch_size=50, step_size=10):
    r"""The :func:`running_mean` decorator is used to add a :class:`.RunningMean` to the :class:`.MetricTree`.

    As with the other decorators, a :class:`.MetricFactory` is created which will do this upon the call to
    :meth:`.MetricFactory.build`. If the inner class is not / does not build a :class:`.MetricTree` then one will be
    created. The :class:`.RunningMean` will be wrapped in a :class:`.ToDict` (with 'running\_' prepended to the name)
    for simplicity.

    .. note::
        The decorator function does not need to be called if not desired, both: `@running_mean` and
        `@running_mean()` are acceptable.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.running_mean(step_size=2) # Update every 2 steps
        ... @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric = my_metric().build()
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {'running_my_metric': 4.0}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {'running_my_metric': 4.0}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8, triggers update
        {'running_my_metric': 6.0}

    :param clazz: The class to *decorate*, or None when the decorator is used in its called form
    :param batch_size: See :class:`.RunningMean`
    :param step_size: See :class:`.RunningMean`
    :return: decorator or :class:`.MetricFactory`
    """
    # The docstring above is a raw string because it contains the reST escape `\_`, which is an invalid escape
    # sequence in a normal string literal.

    # Called form (`@running_mean(...)`): nothing to decorate yet, so hand back a decorator that re-enters this
    # function with the captured settings. Returning early avoids building a factory class that is never used.
    if clazz is None:
        def decorator(clazz):
            return running_mean(clazz, batch_size=batch_size, step_size=step_size)
        return decorator

    class RunningMeanFactory(MetricFactory):
        def __init__(self, *args, **kwargs):
            # Forward all constructor arguments to the decorated class.
            self.inner = clazz(*args, **kwargs)

        def build(self):
            # Resolve a nested factory into the object it produces.
            if isinstance(self.inner, MetricFactory):
                inner = self.inner.build()
            else:
                inner = self.inner

            # Children can only be attached to a tree, so promote anything else to a tree root.
            if not isinstance(inner, MetricTree):
                inner = MetricTree(inner)

            running = RunningMean('running_' + inner.name, batch_size=batch_size, step_size=step_size)
            inner.add_child(ToDict(running))
            return inner

    return RunningMeanFactory