leezu commented on a change in pull request #18083:
URL: https://github.com/apache/incubator-mxnet/pull/18083#discussion_r422279290



##########
File path: tests/python/train/test_mlp.py
##########
@@ -37,8 +37,9 @@ def test_mlp(tmpdir):
 
     def accuracy(label, pred):
         py = np.argmax(pred, axis=1)
-        return np.sum(py == label) / float(label.size)
-
+        return np.sum(py == label.astype(py)) / float(label.size)
+        # currently mxnet.numpy (which used in gluon.metric) did not support "==" between different types

Review comment:
       Reference https://github.com/apache/incubator-mxnet/issues/18137

##########
File path: tests/python/unittest/test_metric.py
##########
@@ -391,19 +367,23 @@ def test_single_array_input():
     pred = mx.nd.array([[1,2,3,4]])
     label = pred + 0.1
 
-    mse = mx.metric.create('mse')
+    mse = mx.gluon.metric.create('mse')
     mse.update(label, pred)
     _, mse_res = mse.get()
     np.testing.assert_almost_equal(mse_res, 0.01)
 
-    mae = mx.metric.create('mae')
+    mae = mx.gluon.metric.create('mae')
     mae.update(label, pred)
     mae.get()
     _, mae_res = mae.get()
     np.testing.assert_almost_equal(mae_res, 0.1)
 
-    rmse = mx.metric.create('rmse')
+    rmse = mx.gluon.metric.create('rmse')
     rmse.update(label, pred)
     rmse.get()
     _, rmse_res = rmse.get()
     np.testing.assert_almost_equal(rmse_res, 0.1)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

Review comment:
       This `if __name__ == '__main__':` / `nose.runmodule()` block can be deleted.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to