szha closed pull request #9930: Support single array input for metric
URL: https://github.com/apache/incubator-mxnet/pull/9930
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index ddffc01bd23..ff4cce944e0 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -30,8 +30,25 @@ from . import registry -def check_label_shapes(labels, preds, shape=0): - if shape == 0: +def check_label_shapes(labels, preds, wrap=False, shape=False): + """Helper function for checking shape of label and prediction + + Parameters + ---------- + labels : list of `NDArray` + The labels of the data. + + preds : list of `NDArray` + Predicted values. + + wrap : boolean + If True, wrap labels/preds in a list if they are single NDArray + + shape : boolean + If True, check the shape of labels and preds; + Otherwise only check their length. + """ + if not shape: label_shape, pred_shape = len(labels), len(preds) else: label_shape, pred_shape = labels.shape, preds.shape @@ -40,6 +57,13 @@ def check_label_shapes(labels, preds, shape=0): raise ValueError("Shape of labels {} does not match shape of " "predictions {}".format(label_shape, pred_shape)) + if wrap: + if isinstance(labels, ndarray.ndarray.NDArray): + labels = [labels] + if isinstance(preds, ndarray.ndarray.NDArray): + preds = [preds] + + return labels, preds class EvalMetric(object): """Base class for all evaluation metrics. @@ -386,7 +410,7 @@ def update(self, labels, preds): Prediction values for samples. Each prediction value can either be the class index, or a vector of likelihoods for all classes. 
""" - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: @@ -394,7 +418,7 @@ def update(self, labels, preds): pred_label = pred_label.asnumpy().astype('int32') label = label.asnumpy().astype('int32') - check_label_shapes(label, pred_label) + labels, preds = check_label_shapes(label, pred_label) self.sum_metric += (pred_label.flat == label.flat).sum() self.num_inst += len(pred_label.flat) @@ -456,7 +480,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred_label in zip(labels, preds): assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims' @@ -614,7 +638,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): self.metrics.update_binary_stats(label, pred) @@ -785,7 +809,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -793,6 +817,8 @@ def update(self, labels, preds): if len(label.shape) == 1: label = label.reshape(label.shape[0], 1) + if len(pred.shape) == 1: + pred = pred.reshape(pred.shape[0], 1) self.sum_metric += numpy.abs(label - pred).mean() self.num_inst += 1 # numpy.prod(label.shape) @@ -843,7 +869,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. 
""" - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -851,6 +877,8 @@ def update(self, labels, preds): if len(label.shape) == 1: label = label.reshape(label.shape[0], 1) + if len(pred.shape) == 1: + pred = pred.reshape(pred.shape[0], 1) self.sum_metric += ((label - pred)**2.0).mean() self.num_inst += 1 # numpy.prod(label.shape) @@ -901,7 +929,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -909,6 +937,8 @@ def update(self, labels, preds): if len(label.shape) == 1: label = label.reshape(label.shape[0], 1) + if len(pred.shape) == 1: + pred = pred.reshape(pred.shape[0], 1) self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean()) self.num_inst += 1 @@ -969,7 +999,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -1037,7 +1067,7 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for label, pred in zip(labels, preds): label = label.asnumpy() @@ -1095,9 +1125,10 @@ def update(self, labels, preds): preds : list of `NDArray` Predicted values. """ - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) + for label, pred in zip(labels, preds): - check_label_shapes(label, pred, 1) + check_label_shapes(label, pred, False, True) label = label.asnumpy() pred = pred.asnumpy() self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] @@ -1209,7 +1240,7 @@ def update(self, labels, preds): Predicted values. 
""" if not self._allow_extra_outputs: - check_label_shapes(labels, preds) + labels, preds = check_label_shapes(labels, preds, True) for pred, label in zip(preds, labels): label = label.asnumpy() diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index fee8b66e3af..bcb0e2d9bf8 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -120,6 +120,27 @@ def test_pearsonr(): _, pearsonr = metric.get() assert pearsonr == pearsonr_expected +def test_single_array_input(): + pred = mx.nd.array([[1,2,3,4]]) + label = pred + 0.1 + + mse = mx.metric.create('mse') + mse.update(label, pred) + _, mse_res = mse.get() + np.testing.assert_almost_equal(mse_res, 0.01) + + mae = mx.metric.create('mae') + mae.update(label, pred) + mae.get() + _, mae_res = mae.get() + np.testing.assert_almost_equal(mae_res, 0.1) + + rmse = mx.metric.create('rmse') + rmse.update(label, pred) + rmse.get() + _, rmse_res = rmse.get() + np.testing.assert_almost_equal(rmse_res, 0.1) + if __name__ == '__main__': import nose nose.runmodule() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services