wkcn commented on a change in pull request #14443: Mxnet allclose
URL: https://github.com/apache/incubator-mxnet/pull/14443#discussion_r271518449
##########
File path: python/mxnet/test_utils.py
##########
@@ -476,24 +487,70 @@ def assert_almost_equal(a, b, rtol=None, atol=None,
names=('a', 'b'), equal_nan=
Parameters
----------
- a : np.ndarray
- b : np.ndarray
- threshold : None or float
- The checking threshold. Default threshold will be used if set to
``None``.
+ a : np.ndarray or mx.nd.array
+ b : np.ndarray or mx.nd.array
+ rtol : None or float
+ The relative threshold. Default threshold will be used if set to
``None``.
+ atol : None or float
+ The absolute threshold. Default threshold will be used if set to
``None``.
+ names : tuple of names, optional
+ The names used in error message when an exception occurs
+ equal_nan : boolean, optional
+ The flag determining how to treat NAN values in comparison
"""
rtol = get_rtol(rtol)
atol = get_atol(atol)
- if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
- return
+
+ use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
+ if not use_np_allclose:
+ if not (hasattr(a, 'context') and hasattr(b, 'context') and a.context
== b.context and a.dtype == b.dtype):
+ use_np_allclose = True
+ if isinstance(a, mx.nd.NDArray):
+ a = a.asnumpy()
+ if isinstance(b, mx.nd.NDArray):
+ b = b.asnumpy()
+
+ if use_np_allclose:
+ if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
+ return
+ else:
+ output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan)
+ if output.asnumpy() == 1:
Review comment:
Thanks for your contribution! It may be better to use `asscalar()`.
Edit:
I see. It needs to call `asnumpy()` here, because the output has shape `()` and `asscalar()` only accepts arrays of shape `(1,)`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services