szha closed pull request #11749: [MXNET-8230] test_operator_gpu.test_rms fails
URL: https://github.com/apache/incubator-mxnet/pull/11749
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below because GitHub would not display it otherwise:
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index a5b3d4047df..fdf7d279d9c 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -835,8 +835,7 @@ def update(self, index, weight, grad, state):
if self.clip_weights:
mx.ndarray.clip(weight, -self.clip_weights, self.clip_weights,
out=weight)
-@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. Tracked at https://github.com/apache/incubator-mxnet/issues/8230")
-@with_seed(0)
+@with_seed()
def test_rms():
opt1 = PyRMSProp
opt2 = mx.optimizer.RMSProp
@@ -848,6 +847,9 @@ def test_rms():
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
for dtype in [np.float16, np.float32]:
+ # Reduce foating point compare tolerance to avoid flaky test failure.
+ rtol, atol = (1e-1, 1e-1) if dtype is np.float16 else (1e-2, 1e-2)
+
for cw_option in cw_options:
for cg_option in cg_options:
for center_option in center_options:
@@ -865,9 +867,9 @@ def test_rms():
('multi_precision' not in kwarg or
not kwarg['multi_precision'])):
continue
- compare_optimizer(opt1(**kwarg),
opt2(**kwarg), shape, dtype)
+ compare_optimizer(opt1(**kwarg),
opt2(**kwarg), shape, dtype, rtol=rtol, atol=atol)
if (default_context() == mx.cpu()):
- compare_optimizer(opt1(**kwarg),
opt2(**kwarg), shape, dtype, g_stype='row_sparse')
+ compare_optimizer(opt1(**kwarg),
opt2(**kwarg), shape, dtype, g_stype='row_sparse', rtol=rtol, atol=atol)
class PyFtrl(mx.optimizer.Optimizer):
"""The Ftrl optimizer.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services