leezu commented on a change in pull request #18083: URL: https://github.com/apache/incubator-mxnet/pull/18083#discussion_r422279290
########## File path: tests/python/train/test_mlp.py ########## @@ -37,8 +37,9 @@ def test_mlp(tmpdir): def accuracy(label, pred): py = np.argmax(pred, axis=1) - return np.sum(py == label) / float(label.size) - + return np.sum(py == label.astype(py)) / float(label.size) + # currently mxnet.numpy (which used in gluon.metric) did not support "==" between different types Review comment: Reference https://github.com/apache/incubator-mxnet/issues/18137 ########## File path: tests/python/unittest/test_metric.py ########## @@ -391,19 +367,23 @@ def test_single_array_input(): pred = mx.nd.array([[1,2,3,4]]) label = pred + 0.1 - mse = mx.metric.create('mse') + mse = mx.gluon.metric.create('mse') mse.update(label, pred) _, mse_res = mse.get() np.testing.assert_almost_equal(mse_res, 0.01) - mae = mx.metric.create('mae') + mae = mx.gluon.metric.create('mae') mae.update(label, pred) mae.get() _, mae_res = mae.get() np.testing.assert_almost_equal(mae_res, 0.1) - rmse = mx.metric.create('rmse') + rmse = mx.gluon.metric.create('rmse') rmse.update(label, pred) rmse.get() _, rmse_res = rmse.get() np.testing.assert_almost_equal(rmse_res, 0.1) + +if __name__ == '__main__': + import nose + nose.runmodule() Review comment: This can be deleted ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org