szha closed pull request #9930: Support single array input for metric
URL: https://github.com/apache/incubator-mxnet/pull/9930
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index ddffc01bd23..ff4cce944e0 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -30,8 +30,25 @@
 from . import registry
 
 
-def check_label_shapes(labels, preds, shape=0):
-    if shape == 0:
+def check_label_shapes(labels, preds, wrap=False, shape=False):
+    """Helper function for checking shape of label and prediction
+
+    Parameters
+    ----------
+    labels : list of `NDArray`
+        The labels of the data.
+
+    preds : list of `NDArray`
+        Predicted values.
+
+    wrap : boolean
+        If True, wrap labels/preds in a list if they are a single NDArray.
+
+    shape : boolean
+        If True, check the shape of labels and preds;
+        otherwise only check their length.
+    """
+    if not shape:
         label_shape, pred_shape = len(labels), len(preds)
     else:
         label_shape, pred_shape = labels.shape, preds.shape
@@ -40,6 +57,13 @@ def check_label_shapes(labels, preds, shape=0):
         raise ValueError("Shape of labels {} does not match shape of "
                          "predictions {}".format(label_shape, pred_shape))
 
+    if wrap:
+        if isinstance(labels, ndarray.ndarray.NDArray):
+            labels = [labels]
+        if isinstance(preds, ndarray.ndarray.NDArray):
+            preds = [preds]
+
+    return labels, preds
 
 class EvalMetric(object):
     """Base class for all evaluation metrics.
@@ -386,7 +410,7 @@ def update(self, labels, preds):
             Prediction values for samples. Each prediction value can either be the class index,
             or a vector of likelihoods for all classes.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred_label in zip(labels, preds):
             if pred_label.shape != label.shape:
@@ -394,7 +418,7 @@ def update(self, labels, preds):
             pred_label = pred_label.asnumpy().astype('int32')
             label = label.asnumpy().astype('int32')
 
-            check_label_shapes(label, pred_label)
+            check_label_shapes(label, pred_label)  # length check only; do not clobber the outer labels/preds
 
             self.sum_metric += (pred_label.flat == label.flat).sum()
             self.num_inst += len(pred_label.flat)
@@ -456,7 +480,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred_label in zip(labels, preds):
             assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
@@ -614,7 +638,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             self.metrics.update_binary_stats(label, pred)
@@ -785,7 +809,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -793,6 +817,8 @@ def update(self, labels, preds):
 
             if len(label.shape) == 1:
                 label = label.reshape(label.shape[0], 1)
+            if len(pred.shape) == 1:
+                pred = pred.reshape(pred.shape[0], 1)
 
             self.sum_metric += numpy.abs(label - pred).mean()
             self.num_inst += 1 # numpy.prod(label.shape)
@@ -843,7 +869,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -851,6 +877,8 @@ def update(self, labels, preds):
 
             if len(label.shape) == 1:
                 label = label.reshape(label.shape[0], 1)
+            if len(pred.shape) == 1:
+                pred = pred.reshape(pred.shape[0], 1)
 
             self.sum_metric += ((label - pred)**2.0).mean()
             self.num_inst += 1 # numpy.prod(label.shape)
@@ -901,7 +929,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -909,6 +937,8 @@ def update(self, labels, preds):
 
             if len(label.shape) == 1:
                 label = label.reshape(label.shape[0], 1)
+            if len(pred.shape) == 1:
+                pred = pred.reshape(pred.shape[0], 1)
 
             self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean())
             self.num_inst += 1
@@ -969,7 +999,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -1037,7 +1067,7 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -1095,9 +1125,10 @@ def update(self, labels, preds):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
+
         for label, pred in zip(labels, preds):
-            check_label_shapes(label, pred, 1)
+            check_label_shapes(label, pred, False, True)
             label = label.asnumpy()
             pred = pred.asnumpy()
             self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
@@ -1209,7 +1240,7 @@ def update(self, labels, preds):
             Predicted values.
         """
         if not self._allow_extra_outputs:
-            check_label_shapes(labels, preds)
+            labels, preds = check_label_shapes(labels, preds, True)
 
         for pred, label in zip(preds, labels):
             label = label.asnumpy()
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index fee8b66e3af..bcb0e2d9bf8 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -120,6 +120,27 @@ def test_pearsonr():
     _, pearsonr = metric.get()
     assert pearsonr == pearsonr_expected
 
+def test_single_array_input():
+    pred = mx.nd.array([[1,2,3,4]])
+    label = pred + 0.1
+
+    mse = mx.metric.create('mse')
+    mse.update(label, pred)
+    _, mse_res = mse.get()
+    np.testing.assert_almost_equal(mse_res, 0.01)
+
+    mae = mx.metric.create('mae')
+    mae.update(label, pred)
+    _, mae_res = mae.get()
+    np.testing.assert_almost_equal(mae_res, 0.1)
+
+    rmse = mx.metric.create('rmse')
+    rmse.update(label, pred)
+    _, rmse_res = rmse.get()
+    np.testing.assert_almost_equal(rmse_res, 0.1)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
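
For reference, a minimal sketch of the new check_label_shapes semantics
(names taken from the diff above; behavior inferred from the merged code,
so treat it as illustration rather than official documentation):

    import mxnet as mx
    from mxnet.metric import check_label_shapes

    a = mx.nd.array([1.0, 2.0])
    b = mx.nd.array([1.1, 2.1])

    # wrap=True: bare NDArrays pass the length check and come back
    # wrapped in single-element lists, ready for per-sample iteration.
    labels, preds = check_label_shapes(a, b, wrap=True)
    assert isinstance(labels, list) and isinstance(preds, list)

    # shape=True: compare full shapes instead of lengths; a mismatch
    # raises ValueError.
    check_label_shapes(a, b, shape=True)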

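And an end-to-end sketch of what the change enables, mirroring
test_single_array_input above (before this PR, update() expected lists
of NDArrays):

    import mxnet as mx
    import numpy as np

    pred = mx.nd.array([[1, 2, 3, 4]])
    label = pred + 0.1

    mse = mx.metric.create('mse')
    mse.update(label, pred)  # bare NDArrays; no [label], [pred] wrapping
    _, mse_res = mse.get()
    np.testing.assert_almost_equal(mse_res, 0.01)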

 
