leezu commented on a change in pull request #18083:
URL: https://github.com/apache/incubator-mxnet/pull/18083#discussion_r412625767



##########
File path: python/mxnet/gluon/metric.py
##########
@@ -619,90 +583,100 @@ def update_binary_stats(self, label, pred):
         """
         pred = pred.asnumpy()
         label = label.asnumpy().astype('int32')
-        pred_label = numpy.argmax(pred, axis=1)
-
-        check_label_shapes(label, pred)
-        if len(numpy.unique(label)) > 2:
-            raise ValueError("%s currently only supports binary 
classification."
-                             % self.__class__.__name__)
+        if self.class_type == "binary":
+            self._set(1)
+            if len(numpy.unique(label)) > 2:
+                raise ValueError("Wrong label for binary classification.")
+            if pred.shape == label.shape:
+                pass
+            elif pred.shape[-1] > 2:
+                raise ValueError("The shape of prediction {} is wrong for 
binary classification.".format(pred.shape))
+            elif pred.shape[-1] == 2:
+                pred = pred.reshape(-1, 2)[:, 1]     
+            pred_label = predict_with_threshold(pred, self.threshold).flat
+            label = label.flat
+            
+        elif self.class_type == "multiclass":
+            num = pred.shape[-1]
+            self._set(num)
+            assert label.max() < num, "pred contains fewer classes than label!"
+            pred_label = one_hot(pred.argmax(axis=-1).reshape(-1), num)        
 
+            label = one_hot(label.reshape(-1), num)
+            
+        elif self.class_type == "multilabel":
+            num = pred.shape[-1]
+            self._set(num)
+            assert pred.shape == label.shape, "The shape of label should be 
same as that of prediction for multilabel classification."
+            pred_label = predict_with_threshold(pred, 
self.threshold).reshape(-1, num)
+            label = label.reshape(-1, num)
+        else:
+            raise ValueError("Wrong class_type {}! Only supports ['binary', 
'multiclass', 'multilabel']".format(self.class_type))
+            
+        check_label_shapes(label, pred_label)
+        
         pred_true = (pred_label == 1)
         pred_false = 1 - pred_true
         label_true = (label == 1)
         label_false = 1 - label_true
 
-        true_pos = (pred_true * label_true).sum()
-        false_pos = (pred_true * label_false).sum()
-        false_neg = (pred_false * label_true).sum()
-        true_neg = (pred_false * label_false).sum()
+        true_pos = (pred_true * label_true).sum(0)
+        false_pos = (pred_true * label_false).sum(0)
+        false_neg = (pred_false * label_true).sum(0)
+        true_neg = (pred_false * label_false).sum(0)
         self.true_positives += true_pos
-        self.global_true_positives += true_pos
         self.false_positives += false_pos
-        self.global_false_positives += false_pos
         self.false_negatives += false_neg
-        self.global_false_negatives += false_neg
         self.true_negatives += true_neg
-        self.global_true_negatives += true_neg
 
     @property
     def precision(self):
-        if self.true_positives + self.false_positives > 0:
-            return float(self.true_positives) / (self.true_positives + self.false_positives)
+        if self.num_classes is not None:
+            return self.true_positives / numpy.maximum(self.true_positives + self.false_positives, 1e-12)
         else:
             return 0.
 
     @property
     def global_precision(self):
-        if self.global_true_positives + self.global_false_positives > 0:
-            return float(self.global_true_positives) / (self.global_true_positives + self.global_false_positives)
+        if self.num_classes is not None:
+            return self.true_positives.sum() / numpy.maximum(self.true_positives.sum() + self.false_positives.sum(), 1e-12)
         else:
             return 0.
-
+            
     @property
     def recall(self):
-        if self.true_positives + self.false_negatives > 0:
-            return float(self.true_positives) / (self.true_positives + self.false_negatives)
+        if self.num_classes is not None:
+            return self.true_positives / numpy.maximum(self.true_positives + self.false_negatives, 1e-12)
         else:
             return 0.
 
     @property
     def global_recall(self):
-        if self.global_true_positives + self.global_false_negatives > 0:
-            return float(self.global_true_positives) / (self.global_true_positives + self.global_false_negatives)
+        if self.num_classes is not None:
+            return self.true_positives.sum() / numpy.maximum(self.true_positives.sum() + self.false_negatives.sum(), 1e-12)
         else:
             return 0.
-
+            
     @property
     def fscore(self):
-        if self.precision + self.recall > 0:
-            return 2 * self.precision * self.recall / (self.precision + self.recall)
-        else:
-            return 0.
+        return (1 + self.beta ** 2) * self.precision * self.recall / numpy.maximum(self.beta ** 2 * self.precision + self.recall, 1e-12)
 
     @property
     def global_fscore(self):

Review comment:
       Would it make sense to adjust the name?
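
      For context, here is a minimal numpy sketch (not from the PR) of what the multiclass branch of the rewritten `update_binary_stats` ends up accumulating. `one_hot` below is a local stand-in for the helper the diff calls, and `confusion_counts` is a hypothetical wrapper around the accumulation; the key point is that `sum(0)` keeps one count per class instead of a single scalar.

      import numpy

      def one_hot(idx, num):
          # Stand-in for the one_hot helper used in the diff: turns integer
          # class indices of shape (N,) into an (N, num) matrix of 0/1 flags.
          out = numpy.zeros((idx.shape[0], num), dtype='int32')
          out[numpy.arange(idx.shape[0]), idx] = 1
          return out

      def confusion_counts(label, pred):
          # Mirrors the multiclass branch: argmax over the last axis, one-hot
          # both prediction and label, then sum over axis 0 so every count
          # becomes a per-class vector.
          num = pred.shape[-1]
          pred_label = one_hot(pred.argmax(axis=-1).reshape(-1), num)
          label = one_hot(label.reshape(-1), num)

          pred_true = (pred_label == 1)
          pred_false = 1 - pred_true
          label_true = (label == 1)
          label_false = 1 - label_true

          true_pos = (pred_true * label_true).sum(0)
          false_pos = (pred_true * label_false).sum(0)
          false_neg = (pred_false * label_true).sum(0)
          true_neg = (pred_false * label_false).sum(0)
          return true_pos, false_pos, false_neg, true_neg

      # 4 samples, 3 classes
      label = numpy.array([0, 1, 2, 1])
      pred = numpy.array([[0.8, 0.1, 0.1],
                          [0.2, 0.7, 0.1],
                          [0.3, 0.3, 0.4],
                          [0.6, 0.3, 0.1]])
      tp, fp, fn, tn = confusion_counts(label, pred)
      print(tp, fp, fn, tn)   # tp=[1 1 1], fp=[1 0 0], fn=[0 1 0], tn=[2 2 3]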
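
      And a rough sketch of what the two flavours of the properties compute once the counters are per-class vectors, which is where the question about the `global_*` naming comes from. `fbeta_from_counts` is a hypothetical helper of mine, not code from the PR; the element-wise ratios give one value per class, while pooling the counters first (as the `global_*` properties do) is the usual micro-averaged definition.

      import numpy

      def fbeta_from_counts(tp, fp, fn, beta=1.0):
          # Element-wise over the class axis, with 1e-12 guarding the
          # denominators, mirroring the precision/recall/fscore properties.
          precision = tp / numpy.maximum(tp + fp, 1e-12)
          recall = tp / numpy.maximum(tp + fn, 1e-12)
          per_class = ((1 + beta ** 2) * precision * recall
                       / numpy.maximum(beta ** 2 * precision + recall, 1e-12))

          # Pooled counts, as in the global_* properties: sum over classes
          # first, then form the same ratios (micro-averaging).
          p = tp.sum() / numpy.maximum(tp.sum() + fp.sum(), 1e-12)
          r = tp.sum() / numpy.maximum(tp.sum() + fn.sum(), 1e-12)
          pooled = ((1 + beta ** 2) * p * r
                    / numpy.maximum(beta ** 2 * p + r, 1e-12))
          return per_class, pooled

      # Counters as accumulated by update_binary_stats (one entry per class).
      tp = numpy.array([1, 1, 1])
      fp = numpy.array([1, 0, 0])
      fn = numpy.array([0, 1, 0])
      per_class_f1, micro_f1 = fbeta_from_counts(tp, fp, fn)
      print(per_class_f1)   # approximately [0.667, 0.667, 1.0], one per class
      print(micro_f1)       # 0.75 from the pooled counts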




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

