from torchbearer import metrics
import sklearn.metrics
import numpy as np
@metrics.default_for_key('roc_auc')
@metrics.default_for_key('roc_auc_score')
class RocAucScore(metrics.EpochLambda):
    """Area Under ROC curve metric.

    .. note::

       Requires :mod:`sklearn.metrics`.

    :param one_hot_labels: If True, convert the labels to a one hot encoding. Required if they are not already.
    :type one_hot_labels: bool
    :param one_hot_offset: Subtracted from class labels, use if not already zero based.
    :type one_hot_offset: int
    :param one_hot_classes: Number of classes for the one hot encoding.
    :type one_hot_classes: int
    :raises ValueError: When the metric is computed with ``one_hot_labels=True`` and some label,
        after subtracting ``one_hot_offset``, lies outside ``[0, one_hot_classes)``.
    """

    def __init__(self, one_hot_labels=True, one_hot_offset=0, one_hot_classes=10):
        def to_categorical(y):
            # Shift the raw labels to zero-based class indices.
            indices = y - one_hot_offset
            # Guard explicitly: numpy fancy indexing silently wraps negative
            # indices (e.g. from a wrong offset) onto the last classes, which
            # would produce a plausible-looking but wrong AUC.
            if indices.size and (indices.min() < 0 or indices.max() >= one_hot_classes):
                raise ValueError(
                    'RocAucScore: labels minus one_hot_offset ({}) must lie in [0, {}), '
                    'got range [{}, {}]'.format(
                        one_hot_offset, one_hot_classes, indices.min(), indices.max()))
            return np.eye(one_hot_classes, dtype='uint8')[indices]

        def identity(y):
            # Labels are already in the form sklearn expects.
            return y

        process = to_categorical if one_hot_labels else identity

        def roc_auc(y_pred, y_true):
            # NOTE(review): presumably EpochLambda passes the predictions and
            # targets gathered over the epoch as torch tensors; confirm.
            # detach() so .numpy() also works if the tensors still track gradients.
            y_true_np = process(y_true.detach().cpu().numpy())
            y_pred_np = y_pred.detach().cpu().numpy()
            return sklearn.metrics.roc_auc_score(y_true_np, y_pred_np)

        super().__init__('roc_auc_score', roc_auc)