szha closed pull request #9777: [MX-9588] Add micro averaging strategy for F1 metric
URL: https://github.com/apache/incubator-mxnet/pull/9777
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 8bb3f6ee0a..0a02b80a1c 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -475,6 +475,85 @@ def update(self, labels, preds):
             self.num_inst += num_samples
 
 
+class _BinaryClassificationMetrics(object):
+    """
+    Private container class for classification metric statistics. True/false positive and
+     true/false negative counts are sufficient statistics for various classification metrics.
+    This class provides the machinery to track those statistics across mini-batches of
+    (label, prediction) pairs.
+    """
+
+    def __init__(self):
+        self.true_positives = 0
+        self.false_negatives = 0
+        self.false_positives = 0
+        self.true_negatives = 0
+
+    def update_binary_stats(self, label, pred):
+        """
+        Update various binary classification counts for a single (label, pred)
+        pair.
+
+        Parameters
+        ----------
+        label : `NDArray`
+            The labels of the data.
+
+        pred : `NDArray`
+            Predicted values.
+        """
+        pred = pred.asnumpy()
+        label = label.asnumpy().astype('int32')
+        pred_label = numpy.argmax(pred, axis=1)
+
+        check_label_shapes(label, pred)
+        if len(numpy.unique(label)) > 2:
+            raise ValueError("%s currently only supports binary classification."
+                             % self.__class__.__name__)
+
+        for y_pred, y_true in zip(pred_label, label):
+            if y_pred == 1 and y_true == 1:
+                self.true_positives += 1.
+            elif y_pred == 1 and y_true == 0:
+                self.false_positives += 1.
+            elif y_pred == 0 and y_true == 1:
+                self.false_negatives += 1.
+            else:
+                self.true_negatives += 1.
+
+    @property
+    def precision(self):
+        if self.true_positives + self.false_positives > 0:
+            return self.true_positives / (self.true_positives + self.false_positives)
+        else:
+            return 0.
+
+    @property
+    def recall(self):
+        if self.true_positives + self.false_negatives > 0:
+            return self.true_positives / (self.true_positives + self.false_negatives)
+        else:
+            return 0.
+
+    @property
+    def fscore(self):
+        if self.precision + self.recall > 0:
+            return 2 * self.precision * self.recall / (self.precision + self.recall)
+        else:
+            return 0.
+
+    @property
+    def total_examples(self):
+        return self.false_negatives + self.false_positives + \
+               self.true_negatives + self.true_positives
+
+    def reset_stats(self):
+        self.false_positives = 0
+        self.false_negatives = 0
+        self.true_positives = 0
+        self.true_negatives = 0
+
+
 @register
 class F1(EvalMetric):
     """Computes the F1 score of a binary classification problem.
@@ -503,21 +582,27 @@ class F1(EvalMetric):
     label_names : list of str, or None
         Name of labels that should be used when updating with update_dict.
         By default include all labels.
+    average : str, default 'macro'
+        Strategy to be used for aggregating across mini-batches.
+            "macro": average the F1 scores for each batch.
+            "micro": compute a single F1 score across all batches.
 
     Examples
     --------
     >>> predicts = [mx.nd.array([[0.3, 0.7], [0., 1.], [0.4, 0.6]])]
     >>> labels   = [mx.nd.array([0., 1., 1.])]
-    >>> acc = mx.metric.F1()
-    >>> acc.update(preds = predicts, labels = labels)
-    >>> print acc.get()
+    >>> f1 = mx.metric.F1()
+    >>> f1.update(preds = predicts, labels = labels)
+    >>> print f1.get()
     ('f1', 0.8)
     """
 
     def __init__(self, name='f1',
-                 output_names=None, label_names=None):
-        super(F1, self).__init__(
-            name, output_names=output_names, label_names=label_names)
+                 output_names=None, label_names=None, average="macro"):
+        self.average = average
+        self.metrics = _BinaryClassificationMetrics()
+        EvalMetric.__init__(self, name=name,
+                            output_names=output_names, label_names=label_names)
 
     def update(self, labels, preds):
         """Updates the internal evaluation result.
@@ -533,41 +618,21 @@ def update(self, labels, preds):
         check_label_shapes(labels, preds)
 
         for label, pred in zip(labels, preds):
-            pred = pred.asnumpy()
-            label = label.asnumpy().astype('int32')
-            pred_label = numpy.argmax(pred, axis=1)
-
-            check_label_shapes(label, pred)
-            if len(numpy.unique(label)) > 2:
-                raise ValueError("F1 currently only supports binary 
classification.")
-
-            true_positives, false_positives, false_negatives = 0., 0., 0.
-
-            for y_pred, y_true in zip(pred_label, label):
-                if y_pred == 1 and y_true == 1:
-                    true_positives += 1.
-                elif y_pred == 1 and y_true == 0:
-                    false_positives += 1.
-                elif y_pred == 0 and y_true == 1:
-                    false_negatives += 1.
+            self.metrics.update_binary_stats(label, pred)
 
-            if true_positives + false_positives > 0:
-                precision = true_positives / (true_positives + false_positives)
-            else:
-                precision = 0.
-
-            if true_positives + false_negatives > 0:
-                recall = true_positives / (true_positives + false_negatives)
-            else:
-                recall = 0.
-
-            if precision + recall > 0:
-                f1_score = 2 * precision * recall / (precision + recall)
-            else:
-                f1_score = 0.
-
-            self.sum_metric += f1_score
+        if self.average == "macro":
+            self.sum_metric += self.metrics.fscore
             self.num_inst += 1
+            self.metrics.reset_stats()
+        else:
+            self.sum_metric = self.metrics.fscore * self.metrics.total_examples
+            self.num_inst = self.metrics.total_examples
+
+    def reset(self):
+        """Resets the internal evaluation result to initial state."""
+        self.sum_metric = 0.
+        self.num_inst = 0.
+        self.metrics.reset_stats()
 
 
 @register
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index 0f2f27f9eb..fee8b66e3a 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -26,7 +26,6 @@ def check_metric(metric, *args, **kwargs):
 
     assert metric.get_config() == metric2.get_config()
 
-
 def test_metrics():
     check_metric('acc', axis=0)
     check_metric('f1')
@@ -56,18 +55,51 @@ def test_acc():
     assert acc == expected_acc
 
 def test_f1():
-    pred = mx.nd.array([[0.3, 0.7], [1., 0], [0.4, 0.6], [0.6, 0.4], [0.9, 0.1]])
-    label = mx.nd.array([0, 1, 1, 1, 1])
-    positives = np.argmax(pred, axis=1).sum().asscalar()
-    true_positives = (np.argmax(pred, axis=1) == label).sum().asscalar()
-    precision = true_positives / positives
-    overall_positives = label.sum().asscalar()
-    recall = true_positives / overall_positives
-    f1_expected = 2 * (precision * recall) / (precision + recall)
-    metric = mx.metric.create('f1')
-    metric.update([label], [pred])
-    _, f1 = metric.get()
-    assert f1 == f1_expected
+    microF1 = mx.metric.create("f1", average="micro")
+    macroF1 = mx.metric.F1(average="macro")
+
+    assert np.isnan(macroF1.get()[1])
+    assert np.isnan(microF1.get()[1])
+
+    # check divide by zero
+    pred = mx.nd.array([[0.9, 0.1],
+                        [0.8, 0.2]])
+    label = mx.nd.array([0, 0])
+    macroF1.update([label], [pred])
+    microF1.update([label], [pred])
+    assert macroF1.get()[1] == 0.0
+    assert microF1.get()[1] == 0.0
+    macroF1.reset()
+    microF1.reset()
+
+    pred11 = mx.nd.array([[0.1, 0.9],
+                          [0.5, 0.5]])
+    label11 = mx.nd.array([1, 0])
+    pred12 = mx.nd.array([[0.85, 0.15],
+                          [1.0, 0.0]])
+    label12 = mx.nd.array([1, 0])
+    pred21 = mx.nd.array([[0.6, 0.4]])
+    label21 = mx.nd.array([0])
+    pred22 = mx.nd.array([[0.2, 0.8]])
+    label22 = mx.nd.array([1])
+
+    microF1.update([label11, label12], [pred11, pred12])
+    macroF1.update([label11, label12], [pred11, pred12])
+    assert microF1.num_inst == 4
+    assert macroF1.num_inst == 1
+    # f1 = 2 * tp / (2 * tp + fp + fn)
+    fscore1 = 2. * (1) / (2 * 1 + 1 + 0)
+    np.testing.assert_almost_equal(microF1.get()[1], fscore1)
+    np.testing.assert_almost_equal(macroF1.get()[1], fscore1)
+
+    microF1.update([label21, label22], [pred21, pred22])
+    macroF1.update([label21, label22], [pred21, pred22])
+    assert microF1.num_inst == 6
+    assert macroF1.num_inst == 2
+    fscore2 = 2. * (1) / (2 * 1 + 0 + 0)
+    fscore_total = 2. * (1 + 1) / (2 * (1 + 1) + (1 + 0) + (0 + 0))
+    np.testing.assert_almost_equal(microF1.get()[1], fscore_total)
+    np.testing.assert_almost_equal(macroF1.get()[1], (fscore1 + fscore2) / 2.)
 
 def test_perplexity():
     pred = mx.nd.array([[0.8, 0.2], [0.2, 0.8], [0, 1.]])
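
For readers who want to try the new `average` option, here is a minimal, illustrative
usage sketch based on the values exercised in test_f1 above. It assumes an MXNet build
that already includes this change; the expected outputs in the comments follow from the
formulas in the diff and are not taken from an actual run:

    import mxnet as mx

    # Two mini-batches; argmax over the second axis gives the predicted class.
    pred1 = mx.nd.array([[0.1, 0.9], [0.5, 0.5]])   # predicted classes: 1, 0
    label1 = mx.nd.array([1, 0])
    pred2 = mx.nd.array([[0.6, 0.4]])               # predicted class: 0
    label2 = mx.nd.array([0])

    # "macro" (the default) averages the F1 score of each update() call,
    # "micro" pools true/false positive/negative counts across all updates.
    macro_f1 = mx.metric.F1(average="macro")
    micro_f1 = mx.metric.create("f1", average="micro")

    for metric in (macro_f1, micro_f1):
        metric.update([label1], [pred1])
        metric.update([label2], [pred2])
        print(metric.get())

    # Expected: ('f1', 0.5) for macro (mean of per-batch scores 1.0 and 0.0)
    #           ('f1', 1.0) for micro (pooled counts: tp=1, fp=0, fn=0)

The same difference can be seen without MXNet at all. The sketch below (plain Python,
with an ad-hoc helper) pools the sufficient statistics the way
_BinaryClassificationMetrics does and applies the f1 = 2*tp / (2*tp + fp + fn) identity
noted in the test:

    def f1_from_counts(tp, fp, fn):
        """F1 from pooled counts: 2*tp / (2*tp + fp + fn), or 0 if undefined."""
        denom = 2 * tp + fp + fn
        return 2.0 * tp / denom if denom > 0 else 0.0

    # (tp, fp, fn) per mini-batch, matching the two batches in test_f1 above.
    batches = [(1, 1, 0), (1, 0, 0)]

    # macro: average the per-batch F1 scores.
    macro = sum(f1_from_counts(*b) for b in batches) / len(batches)

    # micro: pool the counts first, then compute a single F1.
    tp, fp, fn = (sum(col) for col in zip(*batches))
    micro = f1_from_counts(tp, fp, fn)

    print(macro)  # 0.8333... = (2/3 + 1.0) / 2, i.e. (fscore1 + fscore2) / 2
    print(micro)  # 0.8       = 2*2 / (2*2 + 1 + 0), i.e. fscore_total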


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
