"""
Base metrics are the base classes which represent the metrics supplied with torchbearer. The all use the
:func:`.default_for_key` decorator so that they can be accessed in the call to :class:`.torchbearer.Model` via the
following strings:
- '`acc`' or '`accuracy`': The :class:`.CategoricalAccuracy` metric
- '`loss`': The :class:`.Loss` metric
- '`epoch`': The :class:`.Epoch` metric
- '`roc_auc`' or '`roc_auc_score`': The :class:`.RocAucScore` metric
"""
import torchbearer
from torchbearer import metrics
import torch
class CategoricalAccuracy(metrics.BatchLambda):
    """Per-sample categorical accuracy, registered under the name ``'acc'``.

    The predicted class for each sample is the index returned by :func:`torch.max` along
    dimension 1 of ``y_pred``. The metric value is a float tensor that is 1.0 where that index
    equals ``y_true`` and 0.0 elsewhere.
    """

    def __init__(self):
        def batch_accuracy(y_pred, y_true):
            # NOTE(review): assumes y_pred is (batch, num_classes, ...) scores and y_true holds
            # integer class indices of matching shape - TODO confirm against callers.
            # torch.max(t, dim) returns (values, indices); only the indices are needed.
            predicted_class = torch.max(y_pred, 1)[1]
            hits = predicted_class == y_true
            return hits.float()

        super().__init__('acc', batch_accuracy)
@metrics.default_for_key('acc')
@metrics.default_for_key('accuracy')
@metrics.running_mean
@metrics.std
@metrics.mean
class CategoricalAccuracyFactory(metrics.MetricFactory):
    """Factory that builds :class:`CategoricalAccuracy` metrics.

    Class decorators apply bottom-up: ``metrics.mean`` first, then ``metrics.std``,
    ``metrics.running_mean``, and finally the two ``default_for_key`` registrations. The
    registrations make the metric available under the strings ``'accuracy'`` and ``'acc'``.
    The exact wrapping done by mean/std/running_mean lives in ``torchbearer.metrics``.
    """

    def build(self):
        """Create and return a fresh :class:`CategoricalAccuracy` instance."""
        metric = CategoricalAccuracy()
        return metric
class Loss(metrics.Metric):
    """Metric that reads the ``torchbearer.LOSS`` entry out of the model state.

    Registered under the name ``'loss'``.
    """

    def __init__(self):
        super().__init__('loss')

    def process(self, *args):
        """Return the loss stored in the state.

        Args:
            *args: ``args[0]`` is the model state, indexed by torchbearer state keys. Any further
                positional arguments are ignored.

        Returns:
            Whatever value the state holds under ``torchbearer.LOSS``.
        """
        return args[0][torchbearer.LOSS]
@metrics.default_for_key('loss')
@metrics.running_mean
@metrics.std
@metrics.mean
class LossFactory(metrics.MetricFactory):
    """Factory that builds :class:`Loss` metrics.

    Class decorators apply bottom-up: ``metrics.mean``, ``metrics.std``, ``metrics.running_mean``,
    then ``default_for_key('loss')``. That last decorator makes the metric available under the
    string ``'loss'``. The wrapping behaviour itself is defined in ``torchbearer.metrics``.
    """

    def build(self):
        """Create and return a fresh :class:`Loss` instance."""
        metric = Loss()
        return metric
class Epoch(metrics.Metric):
    """Metric that reads the ``torchbearer.EPOCH`` entry out of the model state.

    :meth:`process` and :meth:`process_final` return the same value. Registered under the name
    ``'epoch'``.
    """

    @staticmethod
    def _process(state):
        # Single lookup shared by process() and process_final().
        return state[torchbearer.EPOCH]

    def __init__(self):
        super().__init__('epoch')

    def process(self, *args):
        """Return the epoch held by the state passed as the first positional argument."""
        return Epoch._process(args[0])

    def process_final(self, *args):
        """Return the epoch held by the state passed as the first positional argument."""
        return Epoch._process(args[0])
@metrics.default_for_key('epoch')
@metrics.to_dict
class EpochFactory(metrics.MetricFactory):
    """Factory that builds :class:`Epoch` metrics, made available under the string ``'epoch'``.

    Only ``metrics.to_dict`` is applied before the key registration. The mean/std/running_mean
    decorators are not used here.
    """

    def build(self):
        """Create and return a fresh :class:`Epoch` instance."""
        metric = Epoch()
        return metric