You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/28 23:32:12 UTC

[GitHub] piiswrong closed pull request #10524: [MXNET-312] Added Matthew's Correlation Coefficient to metrics

piiswrong closed pull request #10524: [MXNET-312] Added Matthew's Correlation Coefficient to metrics
URL: https://github.com/apache/incubator-mxnet/pull/10524
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index aa3ab44c48a..aa378cae509 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -568,6 +568,27 @@ def fscore(self):
         else:
             return 0.
 
+    @property
+    def matthewscc(self):
+        """Return the Matthews Correlation Coefficient of the accumulated counts.
+        Zero factors in the denominator are treated as 1; returns 0. with no examples.
+        """
+        if not self.total_examples:  # nothing accumulated yet
+            return 0.
+
+        true_pos = float(self.true_positives)
+        false_pos = float(self.false_positives)
+        false_neg = float(self.false_negatives)
+        true_neg = float(self.true_negatives)
+        terms = [(true_pos + false_pos),
+                 (true_pos + false_neg),
+                 (true_neg + false_pos),
+                 (true_neg + false_neg)]
+        denom = 1.
+        for t in filter(lambda t: t != 0., terms):  # a zero sum implies a zero numerator
+            denom *= t
+        return ((true_pos * true_neg) - (false_pos * false_neg)) / math.sqrt(denom)
+
     @property
     def total_examples(self):
         return self.false_negatives + self.false_positives + \
@@ -584,7 +605,7 @@ def reset_stats(self):
 class F1(EvalMetric):
     """Computes the F1 score of a binary classification problem.
 
-    The F1 score is equivalent to weighted average of the precision and recall,
+    The F1 score is equivalent to harmonic mean of the precision and recall,
     where the best value is 1.0 and the worst value is 0.0. The formula for F1 score is::
 
         F1 = 2 * (precision * recall) / (precision + recall)
@@ -661,6 +682,107 @@ def reset(self):
         self.metrics.reset_stats()
 
 
+@register
+class MCC(EvalMetric):
+    """Computes the Matthews Correlation Coefficient of a binary classification problem.
+
+    While slower to compute than F1, the MCC can give insight that F1 or accuracy cannot.
+    For instance, if the network always predicts the same class,
+    the MCC is 0, which immediately reveals it. The MCC is also symmetric with respect
+    to positive and negative categorization; however, the labels must contain both
+    positive and negative examples, otherwise the MCC is always 0.
+    An MCC of 0 means no correlation, 1 perfect correlation, and -1 perfect anti-correlation.
+
+    .. math::
+        \\text{MCC} = \\frac{ TP \\times TN - FP \\times FN }
+        {\\sqrt{ (TP + FP) ( TP + FN ) ( TN + FP ) ( TN + FN ) } }
+
+    where zero factors in the denominator are replaced by 1 (avoiding division by zero).
+
+    .. note::
+
+        This version of MCC only supports binary classification.
+
+    Parameters
+    ----------
+    name : str
+        Name of this metric instance for display.
+    output_names : list of str, or None
+        Name of predictions that should be used when updating with update_dict.
+        By default include all predictions.
+    label_names : list of str, or None
+        Name of labels that should be used when updating with update_dict.
+        By default include all labels.
+    average : str, default 'macro'
+        Aggregation strategy across mini-batches; any value but "macro" acts as "micro".
+            "macro": average the MCC of each batch (one batch per ``update`` call).
+            "micro": compute a single MCC from the counts of all batches.
+
+    Examples
+    --------
+    In this example the network almost always predicts positive:
+    >>> false_positives = 1000
+    >>> false_negatives = 1
+    >>> true_positives = 10000
+    >>> true_negatives = 1
+    >>> predicts = [mx.nd.array(
+    ...     [[.3, .7]]*false_positives +
+    ...     [[.7, .3]]*true_negatives +
+    ...     [[.7, .3]]*false_negatives +
+    ...     [[.3, .7]]*true_positives
+    ... )]
+    >>> labels = [mx.nd.array(
+    ...     [0.]*(false_positives + true_negatives) +
+    ...     [1.]*(false_negatives + true_positives)
+    ... )]
+    >>> f1 = mx.metric.F1()
+    >>> f1.update(preds=predicts, labels=labels)
+    >>> mcc = mx.metric.MCC()
+    >>> mcc.update(preds=predicts, labels=labels)
+    >>> print(f1.get())
+    ('f1', 0.95233560306652054)
+    >>> print(mcc.get())
+    ('mcc', 0.01917751877733392)
+    """
+
+    def __init__(self, name='mcc',
+                 output_names=None, label_names=None, average="macro"):
+        self._average = average  # set before the base __init__, which may call reset()
+        self._metrics = _BinaryClassificationMetrics()
+        EvalMetric.__init__(self, name=name,
+                            output_names=output_names, label_names=label_names)
+
+    def update(self, labels, preds):
+        """Updates the internal evaluation result; each call is treated as one batch.
+
+        Parameters
+        ----------
+        labels : list of `NDArray`
+            The labels of the data.
+
+        preds : list of `NDArray`
+            Predicted values.
+        """
+        labels, preds = check_label_shapes(labels, preds, True)
+
+        for label, pred in zip(labels, preds):
+            self._metrics.update_binary_stats(label, pred)  # accumulate this pair's counts
+
+        if self._average == "macro":  # one MCC per update() call, then averaged
+            self.sum_metric += self._metrics.matthewscc
+            self.num_inst += 1
+            self._metrics.reset_stats()  # next call starts from zero counts
+        else:  # micro: counts keep accumulating across calls
+            self.sum_metric = self._metrics.matthewscc * self._metrics.total_examples
+            self.num_inst = self._metrics.total_examples  # sum_metric / num_inst is the pooled MCC
+
+    def reset(self):
+        """Resets the internal evaluation result to initial state."""
+        self.sum_metric = 0.
+        self.num_inst = 0.
+        self._metrics.reset_stats()
+
+
 @register
 class Perplexity(EvalMetric):
     """Computes perplexity.
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index 7bc9c10ce5b..26277d2acff 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -29,6 +29,7 @@ def check_metric(metric, *args, **kwargs):
 def test_metrics():
     check_metric('acc', axis=0)
     check_metric('f1')
+    check_metric('mcc')
     check_metric('perplexity', -1)
     check_metric('pearsonr')
     check_metric('nll_loss')
@@ -122,6 +123,54 @@ def test_f1():
     np.testing.assert_almost_equal(microF1.get()[1], fscore_total)
     np.testing.assert_almost_equal(macroF1.get()[1], (fscore1 + fscore2) / 2.)
 
+def test_mcc():
+    microMCC = mx.metric.create("mcc", average="micro")
+    macroMCC = mx.metric.MCC(average="macro")
+
+    assert np.isnan(microMCC.get()[1])  # no updates yet
+    assert np.isnan(macroMCC.get()[1])
+
+    # all-negative labels: zero denominator terms must give MCC 0, not a division error
+    pred = mx.nd.array([[0.9, 0.1],
+                        [0.8, 0.2]])
+    label = mx.nd.array([0, 0])
+    microMCC.update([label], [pred])
+    macroMCC.update([label], [pred])
+    assert microMCC.get()[1] == 0.0
+    assert macroMCC.get()[1] == 0.0
+    microMCC.reset()
+    macroMCC.reset()
+
+    pred11 = mx.nd.array([[0.1, 0.9],
+                        [0.5, 0.5]])
+    label11 = mx.nd.array([1, 0])
+    pred12 = mx.nd.array([[0.85, 0.15],
+                        [1.0, 0.0]])
+    label12 = mx.nd.array([1, 0])
+    pred21 = mx.nd.array([[0.6, 0.4]])
+    label21 = mx.nd.array([0])
+    pred22 = mx.nd.array([[0.2, 0.8]])
+    label22 = mx.nd.array([1])
+    microMCC.update([label11, label12], [pred11, pred12])
+    macroMCC.update([label11, label12], [pred11, pred12])
+    assert microMCC.num_inst == 4  # micro counts examples
+    assert macroMCC.num_inst == 1  # macro counts update() calls
+    tp1 = 1; fp1 = 0; fn1 = 1; tn1=2  # expected counts for the first update
+    mcc1 = (tp1*tn1 - fp1*fn1) / np.sqrt((tp1+fp1)*(tp1+fn1)*(tn1+fp1)*(tn1+fn1))
+    np.testing.assert_almost_equal(microMCC.get()[1], mcc1)
+    np.testing.assert_almost_equal(macroMCC.get()[1], mcc1)
+
+    microMCC.update([label21, label22], [pred21, pred22])
+    macroMCC.update([label21, label22], [pred21, pred22])
+    assert microMCC.num_inst == 6
+    assert macroMCC.num_inst == 2
+    tp2 = 1; fp2 = 0; fn2 = 0; tn2=1  # expected counts for the second update
+    mcc2 = (tp2*tn2 - fp2*fn2) / np.sqrt((tp2+fp2)*(tp2+fn2)*(tn2+fp2)*(tn2+fn2))
+    tpT = tp1+tp2; fpT = fp1+fp2; fnT = fn1+fn2; tnT = tn1+tn2;
+    mccT = (tpT*tnT - fpT*fnT) / np.sqrt((tpT+fpT)*(tpT+fnT)*(tnT+fpT)*(tnT+fnT))
+    np.testing.assert_almost_equal(microMCC.get()[1], mccT)  # micro: MCC of pooled counts
+    np.testing.assert_almost_equal(macroMCC.get()[1], .5*(mcc1+mcc2))
+
 def test_perplexity():
     pred = mx.nd.array([[0.8, 0.2], [0.2, 0.8], [0, 1.]])
     label = mx.nd.array([0, 1, 1])


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services