This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new dd1f21b  Support single array input for metric (#9930)
dd1f21b is described below

commit dd1f21b4369371f4d20fc8a88c1d10834f8cf53b
Author: Tong He <hetong...@gmail.com>
AuthorDate: Tue Mar 13 11:47:26 2018 -0700

    Support single array input for metric (#9930)
    
    * fix #9865
    
    * add unittest
    
    * fix format
    
    * fix format
    
    * fix superfluous loop in metric
    
    * fix lint
---
 python/mxnet/metric.py               | 59 +++++++++++++++++++++++++++---------
 tests/python/unittest/test_metric.py | 21 +++++++++++++
 2 files changed, 66 insertions(+), 14 deletions(-)
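
A minimal sketch of the reworked helper's contract (the arrays and values
below are illustrative only, not part of the commit): check_label_shapes
now returns its inputs, wrapping a bare NDArray into a one-element list
when wrap=True, so the per-metric update() loops iterate over whole arrays
rather than over rows of a single array.

    import mxnet as mx
    from mxnet.metric import check_label_shapes

    label = mx.nd.array([0, 1, 1])   # bare NDArrays, not lists
    pred = mx.nd.array([0, 1, 0])

    # wrap=True wraps each bare NDArray in a list after the length check
    labels, preds = check_label_shapes(label, pred, wrap=True)
    assert isinstance(labels, list) and isinstance(preds, list)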

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index ddffc01..ff4cce9 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -30,8 +30,25 @@ from . import ndarray
 from . import registry
 
 
-def check_label_shapes(labels, preds, shape=0):
-    if shape == 0:
+def check_label_shapes(labels, preds, wrap=False, shape=False):
+    """Helper function for checking shape of label and prediction
+
+    Parameters
+    ----------
+    labels : list of `NDArray`
+        The labels of the data.
+
+    preds : list of `NDArray`
+        Predicted values.
+
+    wrap : boolean
+        If True, wrap labels/preds in a list if they are single NDArrays.
+
+    shape : boolean
+        If True, check the shape of labels and preds;
+        otherwise only check their length.
+    """
+    if not shape:
         label_shape, pred_shape = len(labels), len(preds)
     else:
         label_shape, pred_shape = labels.shape, preds.shape
@@ -40,6 +57,13 @@ def check_label_shapes(labels, preds, shape=0):
         raise ValueError("Shape of labels {} does not match shape of "
                          "predictions {}".format(label_shape, pred_shape))
 
+    if wrap:
+        if isinstance(labels, ndarray.ndarray.NDArray):
+            labels = [labels]
+        if isinstance(preds, ndarray.ndarray.NDArray):
+            preds = [preds]
+
+    return labels, preds
 
 class EvalMetric(object):
     """Base class for all evaluation metrics.
@@ -386,7 +410,7 @@ class Accuracy(EvalMetric):
             Prediction values for samples. Each prediction value can either be the class index,
             or a vector of likelihoods for all classes.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred_label in zip(labels, preds):
             if pred_label.shape != label.shape:
@@ -394,7 +418,7 @@ class Accuracy(EvalMetric):
             pred_label = pred_label.asnumpy().astype('int32')
             label = label.asnumpy().astype('int32')
 
-            check_label_shapes(label, pred_label)
+            check_label_shapes(label, pred_label)
 
             self.sum_metric += (pred_label.flat == label.flat).sum()
             self.num_inst += len(pred_label.flat)
@@ -456,7 +480,7 @@ class TopKAccuracy(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred_label in zip(labels, preds):
             assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
@@ -614,7 +638,7 @@ class F1(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             self.metrics.update_binary_stats(label, pred)
@@ -785,7 +809,7 @@ class MAE(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -793,6 +817,8 @@ class MAE(EvalMetric):
 
             if len(label.shape) == 1:
                 label = label.reshape(label.shape[0], 1)
+            if len(pred.shape) == 1:
+                pred = pred.reshape(pred.shape[0], 1)
 
             self.sum_metric += numpy.abs(label - pred).mean()
             self.num_inst += 1 # numpy.prod(label.shape)
@@ -843,7 +869,7 @@ class MSE(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -851,6 +877,8 @@ class MSE(EvalMetric):
 
             if len(label.shape) == 1:
                 label = label.reshape(label.shape[0], 1)
+            if len(pred.shape) == 1:
+                pred = pred.reshape(pred.shape[0], 1)
 
             self.sum_metric += ((label - pred)**2.0).mean()
             self.num_inst += 1 # numpy.prod(label.shape)
@@ -901,7 +929,7 @@ class RMSE(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -909,6 +937,8 @@ class RMSE(EvalMetric):
 
             if len(label.shape) == 1:
                 label = label.reshape(label.shape[0], 1)
+            if len(pred.shape) == 1:
+                pred = pred.reshape(pred.shape[0], 1)
 
             self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean())
             self.num_inst += 1
@@ -969,7 +999,7 @@ class CrossEntropy(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -1037,7 +1067,7 @@ class NegativeLogLikelihood(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
 
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
@@ -1095,9 +1125,10 @@ class PearsonCorrelation(EvalMetric):
         preds : list of `NDArray`
             Predicted values.
         """
-        check_label_shapes(labels, preds)
+        labels, preds = check_label_shapes(labels, preds, True)
+
         for label, pred in zip(labels, preds):
-            check_label_shapes(label, pred, 1)
+            check_label_shapes(label, pred, False, True)
             label = label.asnumpy()
             pred = pred.asnumpy()
             self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
@@ -1209,7 +1240,7 @@ class CustomMetric(EvalMetric):
             Predicted values.
         """
         if not self._allow_extra_outputs:
-            check_label_shapes(labels, preds)
+            labels, preds = check_label_shapes(labels, preds, True)
 
         for pred, label in zip(preds, labels):
             label = label.asnumpy()
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index fee8b66..bcb0e2d 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -120,6 +120,27 @@ def test_pearsonr():
     _, pearsonr = metric.get()
     assert pearsonr == pearsonr_expected
 
+def test_single_array_input():
+    pred = mx.nd.array([[1,2,3,4]])
+    label = pred + 0.1
+
+    mse = mx.metric.create('mse')
+    mse.update(label, pred)
+    _, mse_res = mse.get()
+    np.testing.assert_almost_equal(mse_res, 0.01)
+
+    mae = mx.metric.create('mae')
+    mae.update(label, pred)
+    mae.get()
+    _, mae_res = mae.get()
+    np.testing.assert_almost_equal(mae_res, 0.1)
+
+    rmse = mx.metric.create('rmse')
+    rmse.update(label, pred)
+    rmse.get()
+    _, rmse_res = rmse.get()
+    np.testing.assert_almost_equal(rmse_res, 0.1)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
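
For reference, the end-to-end effect mirrors the new unit test: a metric's
update() now accepts a bare NDArray where a list of NDArrays was previously
required. A small usage sketch (values are illustrative):

    import mxnet as mx

    pred = mx.nd.array([[1, 2, 3, 4]])
    label = pred + 0.1

    mse = mx.metric.create('mse')
    mse.update(label, pred)   # no [label], [pred] wrapping needed
    print(mse.get())          # ('mse', ~0.01)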

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.
