This is an automated email from the ASF dual-hosted git repository. zhasheng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new dd1f21b Support single array input for metric (#9930) dd1f21b is described below commit dd1f21b4369371f4d20fc8a88c1d10834f8cf53b Author: Tong He <hetong...@gmail.com> AuthorDate: Tue Mar 13 11:47:26 2018 -0700 Support single array input for metric (#9930) * fix #9865 * add unittest * fix format * fix format * fix superfluous loop in metric * fix lint --- python/mxnet/metric.py | 59 +++++++++++++++++++++++++++--------- tests/python/unittest/test_metric.py | 21 +++++++++++++ 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index ddffc01..ff4cce9 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -30,8 +30,25 @@ from . import ndarray from . import registry -def check_label_shapes(labels, preds, shape=0): - if shape == 0: +def check_label_shapes(labels, preds, wrap=False, shape=False): + """Helper function for checking shape of label and prediction + + Parameters + ---------- + labels : list of `NDArray` + The labels of the data. + + preds : list of `NDArray` + Predicted values. + + wrap : boolean + If True, wrap labels/preds in a list if they are single NDArray + + shape : boolean + If True, check the shape of labels and preds; + Otherwise only check their length. + """ + if not shape: label_shape, pred_shape = len(labels), len(preds) else: label_shape, pred_shape = labels.shape, preds.shape @@ -40,6 +57,13 @@ def check_label_shapes(labels, preds, shape=0): raise ValueError("Shape of labels {} does not match shape of " "predictions {}".format(label_shape, pred_shape)) + if wrap: + if isinstance(labels, ndarray.ndarray.NDArray): + labels = [labels] + if isinstance(preds, ndarray.ndarray.NDArray): + preds = [preds] + + return labels, preds class EvalMetric(object): """Base class for all evaluation metrics. @@ -386,7 +410,7 @@ class Accuracy(EvalMetric): Prediction values for samples. Each prediction value can either be the class index, or a vector of likelihoods for all classes. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: @@ -394,7 +418,7 @@ class Accuracy(EvalMetric): pred_label = pred_label.asnumpy().astype('int32') label = label.asnumpy().astype('int32') - check_label_shapes(label, pred_label) + labels, preds = check_label_shapes(label, pred_label) self.sum_metric += (pred_label.flat == label.flat).sum() self.num_inst += len(pred_label.flat) @@ -456,7 +480,7 @@ class TopKAccuracy(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred_label in zip(labels, preds): assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims' @@ -614,7 +638,7 @@ class F1(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): self.metrics.update_binary_stats(label, pred) @@ -785,7 +809,7 @@ class MAE(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -793,6 +817,8 @@ class MAE(EvalMetric): if len(label.shape) == 1: label = label.reshape(label.shape[0], 1) + if len(pred.shape) == 1: + pred = pred.reshape(pred.shape[0], 1) self.sum_metric += numpy.abs(label - pred).mean() self.num_inst += 1 # numpy.prod(label.shape) @@ -843,7 +869,7 @@ class MSE(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -851,6 +877,8 @@ class MSE(EvalMetric): if len(label.shape) == 1: label = label.reshape(label.shape[0], 1) + if len(pred.shape) == 1: + pred = pred.reshape(pred.shape[0], 1) self.sum_metric += ((label - pred)**2.0).mean() self.num_inst += 1 # numpy.prod(label.shape) @@ -901,7 +929,7 @@ class RMSE(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -909,6 +937,8 @@ class RMSE(EvalMetric): if len(label.shape) == 1: label = label.reshape(label.shape[0], 1) + if len(pred.shape) == 1: + pred = pred.reshape(pred.shape[0], 1) self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean()) self.num_inst += 1 @@ -969,7 +999,7 @@ class CrossEntropy(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -1037,7 +1067,7 @@ class NegativeLogLikelihood(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -1095,9 +1125,10 @@ class PearsonCorrelation(EvalMetric): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) + for label, pred in zip(labels, preds): - check_label_shapes(label, pred, 1) + check_label_shapes(label, pred, False, True) label = label.asnumpy() pred = pred.asnumpy() self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] @@ -1209,7 +1240,7 @@ class CustomMetric(EvalMetric): Predicted values. """ if not self._allow_extra_outputs: - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for pred, label in zip(preds, labels): label = label.asnumpy() diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index fee8b66..bcb0e2d 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -120,6 +120,27 @@ def test_pearsonr(): _, pearsonr = metric.get() assert pearsonr == pearsonr_expected +def test_single_array_input(): + pred = mx.nd.array([[1,2,3,4]]) + label = pred + 0.1 + + mse = mx.metric.create('mse') + mse.update(label, pred) + _, mse_res = mse.get() + np.testing.assert_almost_equal(mse_res, 0.01) + + mae = mx.metric.create('mae') + mae.update(label, pred) + mae.get() + _, mae_res = mae.get() + np.testing.assert_almost_equal(mae_res, 0.1) + + rmse = mx.metric.create('rmse') + rmse.update(label, pred) + rmse.get() + _, rmse_res = rmse.get() + np.testing.assert_almost_equal(rmse_res, 0.1) + if __name__ == '__main__': import nose nose.runmodule() -- To stop receiving notification emails like this one, please contact zhash...@apache.org.