This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 39ea683  AMP support for Numpy ops (#19036)
39ea683 is described below

commit 39ea6832a004c873a89b2bd59d71488bac811e59
Author: mk-61 <[email protected]>
AuthorDate: Fri Oct 2 15:37:03 2020 -0700

    AMP support for Numpy ops (#19036)
    
    * AMP support for numpy ops
    
    * Move misplaced docstring
    
    * Fix numpy submodule handling in AMP
    
    * Fix module selection when AMP-patching numpy ops
    
    * Check if a model exists in AMP init
    
    * Make AMP loss scale public
    
    * Re-enable a test
    
    * Re-disable a test, with the original reason
    
    Co-authored-by: Vladimir Cherepanov <[email protected]>
---
 python/mxnet/contrib/amp/amp.py               | 157 ++++++---
 python/mxnet/contrib/amp/lists/symbol_fp16.py | 489 +++++++++++++++-----------
 python/mxnet/contrib/amp/loss_scaler.py       |  18 +-
 python/mxnet/operator.py                      |  26 ++
 src/operator/contrib/all_finite.cc            |   2 +
 src/operator/tensor/amp_cast.cc               |   2 +
 tests/python/gpu/test_contrib_amp.py          |  69 ++--
 7 files changed, 471 insertions(+), 292 deletions(-)

diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py
index 86bf513..5fde733 100644
--- a/python/mxnet/contrib/amp/amp.py
+++ b/python/mxnet/contrib/amp/amp.py
@@ -25,10 +25,13 @@ __all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 
'convert_model',
 
 from array import array
 import ctypes
+import inspect
 import logging
 import contextlib
+import sys
 import numpy as np
 
+from mxnet import numpy
 from ... import symbol
 from ...context import gpu
 from ...symbol import Symbol
@@ -36,33 +39,35 @@ from ...symbol import contrib as symbol_contrib
 from ... import ndarray
 from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
 from . import lists
-from ...gluon import trainer
+from ...gluon import Block, trainer
 from ... import base
-from ...base import c_str_array, SymbolHandle, check_call, _LIB, mx_uint, 
c_array_buf
+from ...base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
+                     _NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
+                     c_str_array, SymbolHandle, check_call, _LIB, mx_uint, 
c_array_buf)
 from ... import optimizer as opt
 from .loss_scaler import LossScaler
+from ...operator import get_all_registered_operators_grouped
 
 bfloat16 = np.dtype([('bfloat16', np.uint16)])
 
-def _cast_symbol_NDArray(s, dtype):
-    float_types_gpu = (np.float16, np.float32)
-    float_types_cpu = (bfloat16, np.float32)
-    if isinstance(s, Symbol):
-        return symbol.amp_cast(s, dtype=dtype)
-    elif isinstance(s, NDArray):
-        if (s.dtype != dtype and s.dtype in float_types_gpu and 
s.context.device_type != 'cpu'):
-            return ndarray.amp_cast(s, dtype=dtype)
-        elif (s.dtype != dtype and s.dtype in float_types_cpu and 
s.context.device_type == 'cpu'):
-            return ndarray.amp_cast(s, dtype=dtype)
-        else:
-            return s
-    else:
-        return s
+float_types_gpu = (np.float16, np.float32)
+float_types_cpu = (bfloat16, np.float32)
 
-def _get_fun_to_wrap(name, module, submodule_dict):
+def _cast_symbol_NDArray(s, dtype, is_numpy_module=False):
+    if isinstance(s, Symbol):
+        amp_cast = symbol.numpy._internal.amp_cast if is_numpy_module else 
symbol.amp_cast
+        return amp_cast(s, dtype=dtype)
+    if isinstance(s, NDArray):
+        amp_cast = ndarray.numpy._internal.amp_cast if is_numpy_module else 
ndarray.amp_cast
+        if s.dtype != dtype and (s.dtype in float_types_gpu and 
s.context.device_type != 'cpu' or
+                                 s.dtype in float_types_cpu and 
s.context.device_type == 'cpu'):
+            return amp_cast(s, dtype=dtype)
+    return s
+
+def _get_nd_fun_to_wrap(name, module, submodule_dict):
     module_internal = getattr(module, "_internal")
     prefix = base._get_op_name_prefix(name)
-    if len(prefix) > 0:
+    if prefix:
         if prefix != '_random_' or name.endswith('_like'):
             func_name = name[len(prefix):]
             cur_module = submodule_dict[prefix]
@@ -77,8 +82,26 @@ def _get_fun_to_wrap(name, module, submodule_dict):
         cur_module = module
     return func_name, cur_module
 
-def _wrap_symbol_functions(module, target_dtype, target_precision_ops=None,
-                           conditional_fp32_ops=None, fp32_ops=None):
+def _get_np_fun_to_wrap(name, ns_prefix):
+    for pre, mod, subs in ((_NP_OP_PREFIX, 'numpy', _NP_OP_SUBMODULE_LIST),
+                           (_NP_EXT_OP_PREFIX, 'numpy_extension', 
_NP_EXT_OP_SUBMODULE_LIST),
+                           (_NP_INTERNAL_OP_PREFIX, 'numpy._internal', [])):
+        if name.startswith(pre):
+            name = name[len(pre):]
+            for sub in subs:
+                if name.startswith(sub):
+                    return name[len(sub):], 
sys.modules[f'{ns_prefix}.{mod}.{sub[1:-1]}']
+            return name, sys.modules[f'{ns_prefix}.{mod}']
+    assert False
+    return None  # for pylint
+
+def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases, 
get_cond_aliases,
+                           get_fun_to_wrap, target_precision_ops=None, 
conditional_fp32_ops=None,
+                           fp32_ops=None):
+
+    nd_mod = ndarray.numpy._internal if is_numpy_module else ndarray
+    sy_mod = symbol.numpy._internal if is_numpy_module else symbol
+
     def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
         def _new_fun(*args, **kwargs):
             if cond_arg is not None:
@@ -91,9 +114,10 @@ def _wrap_symbol_functions(module, target_dtype, 
target_precision_ops=None,
                     if fp32_param[i]:
                         new_args.append(x)
                     else:
-                        new_args.append(_cast_symbol_NDArray(x, target_dtype))
+                        new_args.append(_cast_symbol_NDArray(x, target_dtype, 
is_numpy_module))
             else:
-                new_args = list(map(lambda x: _cast_symbol_NDArray(x, 
target_dtype), args))
+                new_args = list(map(
+                    lambda x: _cast_symbol_NDArray(x, target_dtype, 
is_numpy_module), args))
             args = tuple(new_args)
             if fp32_param:
                 new_kwargs = {}
@@ -101,10 +125,11 @@ def _wrap_symbol_functions(module, target_dtype, 
target_precision_ops=None,
                     if k in fp32_param:
                         new_kwargs[k] = v
                     else:
-                        new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype)
+                        new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype, 
is_numpy_module)
                     kwargs = new_kwargs
             else:
-                kwargs = {k: _cast_symbol_NDArray(v, target_dtype) for k, v in 
kwargs.items()}
+                kwargs = {k: _cast_symbol_NDArray(v, target_dtype, 
is_numpy_module)
+                          for k, v in kwargs.items()}
             return f(*args, **kwargs)
         _new_fun.__name__ = f.__name__
         _new_fun.__module__ = f.__module__
@@ -126,10 +151,10 @@ def _wrap_symbol_functions(module, target_dtype, 
target_precision_ops=None,
                     if (x.name in aux) or fp32_param[i]:
                         new_inputs.append(x)
                     else:
-                        new_inputs.append(_cast_symbol_NDArray(x, 
target_dtype))
+                        new_inputs.append(_cast_symbol_NDArray(x, 
target_dtype, is_numpy_module))
                 inputs = new_inputs
             else:
-                inputs = list(map(lambda x: _cast_symbol_NDArray(x, 
target_dtype)
+                inputs = list(map(lambda x: _cast_symbol_NDArray(x, 
target_dtype, is_numpy_module)
                                   if x.name not in aux else x, inputs))
             atomic_sym = sym._gen_atomic_symbol()
             wrapped_sym = atomic_sym(*inputs)
@@ -162,11 +187,11 @@ def _wrap_symbol_functions(module, target_dtype, 
target_precision_ops=None,
                             widest_type = np.float32
                 for arr, index, arg in symbols:
                     if arg.dtype != widest_type and arg.dtype == target_dtype:
-                        arr[index] = ndarray.amp_cast(arg, dtype=widest_type)
+                        arr[index] = nd_mod.amp_cast(arg, dtype=widest_type)
             else:
                 # Symbol case
                 sym_to_check = list(map(lambda x: x[2], symbols))
-                casted_syms = symbol.amp_multicast(*sym_to_check, 
num_outputs=len(sym_to_check))
+                casted_syms = sy_mod.amp_multicast(*sym_to_check, 
num_outputs=len(sym_to_check))
                 symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
                                    zip(symbols, casted_syms)))
                 for arr, index, arg in symbols:
@@ -180,54 +205,50 @@ def _wrap_symbol_functions(module, target_dtype, 
target_precision_ops=None,
 
     _wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib) 
else _ndarray_wrapper
 
-    submodule_dict = {}
-    for op_name_prefix in base._OP_NAME_PREFIX_LIST:
-        submodule_dict[op_name_prefix] =\
-                getattr(module, op_name_prefix[1:-1])
     fp32_param_list = list_lp16_use_fp32_params(target_dtype)
     wrap_list = target_precision_ops if target_precision_ops is not None \
                     else list_lp16_ops(target_dtype)
-    for fun_name in wrap_list:
+    for fun_name in get_aliases(wrap_list):
         try:
-            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, 
submodule_dict)
+            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
             f_to_wrap = getattr(cur_module, fun_name)
             fp32_param = fp32_param_list[fun_name] if (fp32_param_list and 
fun_name in fp32_param_list) else None
             setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype, 
fp32_param=fp32_param))
-            if cur_module == module:
+            if not is_numpy_module and cur_module == module:
                 setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype, 
fp32_param=fp32_param))
         except AttributeError:
             raise
 
     wrap_list = fp32_ops if fp32_ops is not None else 
list_fp32_ops(target_dtype)
-    for fun_name in wrap_list:
+    for fun_name in get_aliases(wrap_list):
         try:
-            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, 
submodule_dict)
+            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
             f_to_wrap = getattr(cur_module, fun_name)
             setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
-            if cur_module == module:
+            if not is_numpy_module and cur_module == module:
                 setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
         except AttributeError:
             raise
 
     wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
                     else list_conditional_fp32_ops(target_dtype)
-    for fun_name, arg, arg_values in wrap_list:
+    for fun_name, arg, arg_values in get_cond_aliases(wrap_list):
         try:
-            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, 
submodule_dict)
+            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
             f_to_wrap = getattr(cur_module, fun_name)
             setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, 
cond_arg=(arg, arg_values)))
-            if cur_module == module:
+            if not is_numpy_module and cur_module == module:
                 setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32, 
cond_arg=(arg, arg_values)))
         except AttributeError:
             raise
 
 
-    for fun_name in list_widest_type_cast(target_dtype):
+    for fun_name in get_aliases(list_widest_type_cast(target_dtype)):
         try:
-            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, 
submodule_dict)
+            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
             f_to_wrap = getattr(cur_module, fun_name)
             setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
-            if cur_module == module:
+            if not is_numpy_module and cur_module == module:
                 setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
         except AttributeError:
             raise
@@ -278,6 +299,14 @@ def scale_loss(loss, optimizer_or_trainer):
     else:
         yield optimizer_or_trainer._amp_loss_scaler.loss_scale * loss
 
+def warn_if_model_exists():
+    for f in inspect.stack():
+        for k, v in f.frame.f_locals.items():
+            if isinstance(v, Block):
+                logging.warning('Block %s created in [%s:%d] before AMP init.',
+                                k, f.filename, f.lineno)
+                return
+
 def init(target_dtype='float16', target_precision_ops=None,
          conditional_fp32_ops=None, fp32_ops=None):
     """Initialize AMP (automatic mixed precision).
@@ -310,13 +339,39 @@ def init(target_dtype='float16', 
target_precision_ops=None,
             target_dtype = bfloat16
         else:
             target_dtype = np.dtype(target_dtype)
-        _wrap_symbol_functions(symbol, target_dtype, target_precision_ops,
-                               conditional_fp32_ops, fp32_ops)
-        _wrap_symbol_functions(ndarray, target_dtype, target_precision_ops,
-                               conditional_fp32_ops, fp32_ops)
+
+        warn_if_model_exists()
+
+        ops = get_all_registered_operators_grouped()
+        get_aliases_nd = lambda l: [a for op in l for a in ops[op] if not 
base._is_np_op(a)]
+        get_aliases_np = lambda l: [a for op in l for a in ops[op] if 
base._is_np_op(a)]
+        get_aliases_np_pub = lambda l: [a for op in l for a in ops[op]
+                                        if a.startswith(('_np_', '_npx_'))]
+        get_cond_aliases_nd = lambda l: [(a, *rest) for op, *rest in l for a 
in ops[op]
+                                         if not base._is_np_op(a)]
+        get_cond_aliases_np = lambda l: [(a, *rest) for op, *rest in l for a 
in ops[op]
+                                         if base._is_np_op(a)]
+        get_cond_aliases_np_pub = lambda l: [(a, *rest) for op, *rest in l for 
a in ops[op]
+                                             if a.startswith(('_np_', 
'_npx_'))]
+        sy_submodules = {p:getattr(symbol, p[1:-1]) for p in 
base._OP_NAME_PREFIX_LIST}
+        get_sy_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod, 
sy_submodules)
+        nd_submodules = {p:getattr(ndarray, p[1:-1]) for p in 
base._OP_NAME_PREFIX_LIST}
+        get_nd_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod, 
nd_submodules)
+        get_np_sy_fun = lambda fun, mod: _get_np_fun_to_wrap(fun, 
"mxnet.symbol")
+        get_np_nd_fun = lambda fun, mod: _get_np_fun_to_wrap(fun, 
"mxnet.ndarray")
+        get_np_fun = lambda fun, mode: _get_np_fun_to_wrap(fun, "mxnet")
+        todo = [
+            (symbol, False, get_aliases_nd, get_cond_aliases_nd, get_sy_fun),
+            (ndarray, False, get_aliases_nd, get_cond_aliases_nd, get_nd_fun),
+            (symbol.numpy, True, get_aliases_np, get_cond_aliases_np, 
get_np_sy_fun),
+            (ndarray.numpy, True, get_aliases_np, get_cond_aliases_np, 
get_np_nd_fun),
+            (numpy, True, get_aliases_np_pub, get_cond_aliases_np_pub, 
get_np_fun),
+        ]
         _loss_scaler = LossScaler()
-        _wrap_loss_output_functions(ndarray, _loss_scaler, target_dtype)
-        _wrap_loss_output_functions(symbol, _loss_scaler, target_dtype)
+        for module, is_numpy, get_aliases, get_cond_aliases, get_fun in todo:
+            _wrap_module_functions(module, is_numpy, target_dtype, 
get_aliases, get_cond_aliases,
+                                   get_fun, target_precision_ops, 
conditional_fp32_ops, fp32_ops)
+            _wrap_loss_output_functions(module, _loss_scaler, target_dtype)
 
 def init_trainer(optimizer_or_trainer):
     """Initialize trainer or optimizer to work with AMP dynamic loss scaling.
@@ -339,8 +394,8 @@ def init_trainer(optimizer_or_trainer):
     if isinstance(optimizer_or_trainer, trainer.Trainer):
         optimizer_or_trainer._amp_loss_scaler = loss_scaler
         optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
+        trainer.Trainer.amp_loss_scale = property(lambda self: 
self._amp_loss_scaler.loss_scale)
     elif isinstance(optimizer_or_trainer, opt.Optimizer):
-        # TODO(ptredak): make it work with the optimizer
         raise TypeError("AMP is currently only compatible with Gluon Trainer")
     else:
         raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
diff --git a/python/mxnet/contrib/amp/lists/symbol_fp16.py 
b/python/mxnet/contrib/amp/lists/symbol_fp16.py
index 275d088..db608e4 100644
--- a/python/mxnet/contrib/amp/lists/symbol_fp16.py
+++ b/python/mxnet/contrib/amp/lists/symbol_fp16.py
@@ -18,8 +18,15 @@
 # coding: utf-8
 """Lists of functions whitelisted/blacklisted for automatic mixed precision in 
symbol API."""
 
+from ....runtime import Features
+
+
 # Functions that should be cast to lower precision
 FP16_FUNCS = [
+    '_linalg_gemm',
+    '_linalg_gemm2',
+    '_npi_einsum',
+    '_npi_matmul',
     'Convolution',
     'Deconvolution',
     'FullyConnected',
@@ -35,18 +42,29 @@ FP16_FP32_FUNCS = [
     'BilinearSampler',
     'BlockGrad',
     'Cast',
-    'cast',
     'cast_storage',
+    '_contrib_BatchNormWithReLU',
+    '_contrib_allclose',
+    '_contrib_arange_like',
+    '_contrib_dynamic_reshape',
+    '_contrib_intgemm_fully_connected',
+    '_contrib_intgemm_maxabsolute',
+    '_contrib_intgemm_prepare_data',
+    '_contrib_intgemm_prepare_weight',
+    '_contrib_intgemm_take_weight',
+    '_contrib_quantized_batch_norm',
+    '_contrib_quantized_elemwise_mul',
+    '_contrib_quantized_embedding',
+    '_contrib_mrcnn_mask_target',
+    '_contrib_round_ste',
+    '_contrib_sign_ste',
     'Crop',
     'Dropout',
     'Embedding',
-    '_sparse_Embedding',
-    '_sparse_FullyConnected',
     'Flatten',
     'GridGenerator',
     'Pad',
     'Pooling',
-    'Pooling_v1',
     'ROIPooling',
     'Reshape',
     'SequenceLast',
@@ -57,31 +75,15 @@ FP16_FP32_FUNCS = [
     'SwapAxis',
     'UpSampling',
     '_CachedOp',
+    '_CachedOpThreadSafe',
     '_CrossDeviceCopy',
     '_CustomFunction',
-    '_DivScalar',
-    '_EqualScalar',
-    '_GreaterScalar',
-    '_GreaterEqualScalar',
-    '_LesserScalar',
-    '_LesserEqualScalar',
-    '_LogicalAndScalar',
-    '_LogicalOrScalar',
-    '_LogicalXorScalar',
-    '_MaximumScalar',
-    '_MinimumScalar',
-    '_MinusScalar',
-    '_ModScalar',
-    '_MulScalar',
+    '_FusedOp',
+    '_FusedOpHelper',
+    '_FusedOpOutHelper',
     '_NoGradient',
-    '_NotEqualScalar',
-    '_PlusScalar',
-    '_RMinusScalar',
-    '_RModScalar',
     '_adamw_update',
-    '_add',
     '_arange',
-    '_broadcast_backward',
     '_cond',
     '_contrib_AdaptiveAvgPooling2D',
     '_contrib_BilinearResize2D',
@@ -109,8 +111,6 @@ FP16_FP32_FUNCS = [
     '_contrib_requantize',
     '_copy',
     '_copyto',
-    '_crop_assign',
-    '_crop_assign_scalar',
     '_cvcopyMakeBorder',
     '_cvimdecode',
     '_cvimread',
@@ -125,6 +125,7 @@ FP16_FP32_FUNCS = [
     '_greater_scalar',
     '_greater_equal_scalar',
     '_histogram',
+    '_hypot_scalar',
     '_identity_with_attr_like_rhs',
     '_image_adjust_lighting',
     '_image_flip_left_right',
@@ -133,6 +134,8 @@ FP16_FP32_FUNCS = [
     '_image_random_brightness',
     '_image_random_color_jitter',
     '_image_random_contrast',
+    '_image_random_crop',
+    '_image_random_resized_crop',
     '_image_random_flip_left_right',
     '_image_random_flip_top_bottom',
     '_image_random_hue',
@@ -152,7 +155,162 @@ FP16_FP32_FUNCS = [
     '_mod_scalar',
     '_mp_adamw_update',
     '_mul_scalar',
+    '_multi_adamw_update',
+    '_multi_lamb_update',
+    '_multi_lans_update',
+    '_multi_mp_adamw_update',
+    '_multi_mp_lamb_update',
+    '_multi_mp_lans_update',
     '_not_equal_scalar',
+    '_np_reshape',
+    '_npi_absolute',
+    '_npi_add',
+    '_npi_add_scalar',
+    '_npi_advanced_indexing',
+    '_npi_advanced_indexing_multiple',
+    '_npi_all',
+    '_npi_any',
+    '_npi_arange',
+    '_npi_arccosh',
+    '_npi_arcsinh',
+    '_npi_arctan',
+    '_npi_arctan2',
+    '_npi_arctan2_scalar',
+    '_npi_argmax',
+    '_npi_argmin',
+    '_npi_around',
+    '_npi_atleast_1d',
+    '_npi_atleast_2d',
+    '_npi_atleast_3d',
+    '_npi_bernoulli',
+    '_npi_bincount',
+    '_npi_bitwise_and',
+    '_npi_bitwise_and_scalar',
+    '_npi_bitwise_not',
+    '_npi_bitwise_or',
+    '_npi_bitwise_or_scalar',
+    '_npi_bitwise_xor',
+    '_npi_bitwise_xor_scalar',
+    '_npi_blackman',
+    '_npi_boolean_mask_assign_scalar',
+    '_npi_boolean_mask_assign_tensor',
+    '_npi_broadcast_to',
+    '_npi_cbrt',
+    '_npi_ceil',
+    '_npi_choice',
+    '_npi_copy',
+    '_npi_copysign_scalar',
+    '_npi_cos',
+    '_npi_degrees',
+    '_npi_delete',
+    '_npi_diag',
+    '_npi_diag_indices_from',
+    '_npi_diagflat',
+    '_npi_diagonal',
+    '_npi_diff',
+    '_npi_dsplit',
+    '_npi_equal_scalar',
+    '_npi_exponential',
+    '_npi_eye',
+    '_npi_fill_diagonal',
+    '_npi_fix',
+    '_npi_flip',
+    '_npi_floor',
+    '_npi_fmax_scalar',
+    '_npi_fmin_scalar',
+    '_npi_fmod_scalar',
+    '_npi_full',
+    '_npi_full_like',
+    '_npi_gamma',
+    '_npi_greater_equal_scalar',
+    '_npi_greater_scalar',
+    '_npi_gumbel',
+    '_npi_hamming',
+    '_npi_hanning',
+    '_npi_hsplit',
+    '_npi_identity',
+    '_npi_indices',
+    '_npi_insert_scalar',
+    '_npi_insert_slice',
+    '_npi_insert_tensor',
+    '_npi_interp',
+    '_npi_isinf',
+    '_npi_isfinite',
+    '_npi_isnan',
+    '_npi_isneginf',
+    '_npi_isposinf',
+    '_npi_laplace',
+    '_npi_less_equal_scalar',
+    '_npi_less_scalar',
+    '_npi_logistic',
+    '_npi_lcm',
+    '_npi_lcm_scalar',
+    '_npi_linspace',
+    '_npi_logical_not',
+    '_npi_logical_and_scalar',
+    '_npi_logical_or_scalar',
+    '_npi_logical_xor_scalar',
+    '_npi_logspace',
+    '_npi_max',
+    '_npi_min',
+    '_npi_mod',
+    '_npi_mod_scalar',
+    '_npi_moveaxis',
+    '_npi_multinomial',
+    '_npi_multiply',
+    '_npi_multiply_scalar',
+    '_npi_nan_to_num',
+    '_npi_negative',
+    '_npi_normal',
+    '_npi_normal_n',
+    '_npi_not_equal_scalar',
+    '_npi_ones',
+    '_npi_pad',
+    '_npi_pareto',
+    '_npi_percentile',
+    '_npi_powerd',
+    '_npi_radians',
+    '_npi_rarctan2_scalar',
+    '_npi_rayleigh',
+    '_npi_rcopysign_scalar',
+    '_npi_repeats',
+    '_npi_rfmod_scalar',
+    '_npi_rint',
+    '_npi_rmod_scalar',
+    '_npi_roll',
+    '_npi_rollaxis',
+    '_npi_rot90',
+    '_npi_rsubtract_scalar',
+    '_npi_rtrue_divide_scalar',
+    '_npi_share_memory',
+    '_npi_sign',
+    '_npi_sin',
+    '_npi_sqrt',
+    '_npi_squeeze',
+    '_npi_subtract',
+    '_npi_subtract_scalar',
+    '_npi_tanh',
+    '_npi_transpose',
+    '_npi_tri',
+    '_npi_tril',
+    '_npi_tril_indices',
+    '_npi_triu',
+    '_npi_true_divide',
+    '_npi_true_divide_scalar',
+    '_npi_trunc',
+    '_npi_uniform',
+    '_npi_uniform_n',
+    '_npi_unique',
+    '_npi_weibull',
+    '_npi_where_lscalar',
+    '_npi_where_rscalar',
+    '_npi_where_scalar2',
+    '_npi_zeros',
+    '_npx_constraint_check',
+    '_npx_nonzero',
+    '_npx_relu',
+    '_npx_reshape',
+    '_npx_sigmoid',
     '_onehot_encode',
     '_ones',
     '_plus_scalar',
@@ -189,42 +347,9 @@ FP16_FP32_FUNCS = [
     '_shuffle',
     '_slice_assign',
     '_slice_assign_scalar',
-    '_sparse_abs',
     '_sparse_adagrad_update',
-    '_sparse_adam_update',
-    '_sparse_arccosh',
-    '_sparse_arcsinh',
-    '_sparse_arctan',
-    '_sparse_cast_storage',
-    '_sparse_cbrt',
-    '_sparse_ceil',
-    '_sparse_clip',
-    '_sparse_concat',
-    '_sparse_cos',
-    '_sparse_degrees',
-    '_sparse_fix',
-    '_sparse_floor',
-    '_sparse_ftrl_update',
-    '_sparse_negative',
-    '_sparse_radians',
-    '_sparse_relu',
     '_sparse_retain',
-    '_sparse_rint',
-    '_sparse_round',
-    '_sparse_sgd_mom_update',
-    '_sparse_sgd_update',
-    '_sparse_sigmoid',
-    '_sparse_sign',
-    '_sparse_sin',
-    '_sparse_sinh',
-    '_sparse_slice',
-    '_sparse_sqrt',
-    '_sparse_stop_gradient',
-    '_sparse_tanh',
-    '_sparse_trunc',
-    '_sparse_zeros_like',
     '_split_v2',
-    '_split_v2_backward',
     '_unravel_index',
     '_zeros',
     '_zeros_without_dtype',
@@ -240,16 +365,14 @@ FP16_FP32_FUNCS = [
     'argmax_channel',
     'argmin',
     'batch_take',
-    'broadcast_axes',
     'broadcast_axis',
     'broadcast_like',
     'broadcast_to',
     'cbrt',
     'ceil',
-    'choose_element_0index',
     'clip',
+    'col2im',
     'cos',
-    'crop',
     'degrees',
     'depth_to_space',
     'diag',
@@ -257,64 +380,52 @@ FP16_FP32_FUNCS = [
     'expand_dims',
     'fill_element_0index',
     'fix',
-    'flatten',
-    'flip',
     'floor',
     'ftml_update',
     'ftrl_update',
     'gather_nd',
     'hard_sigmoid',
-    'identity',
+    'im2col',
+    'lamb_update_phase1',
+    'lamb_update_phase2',
     'logical_not',
-    'max_axis',
     'max',
     'min',
-    'min_axis',
+    'mp_lamb_update_phase1',
+    'mp_lamb_update_phase2',
+    'mp_nag_mom_update',
     'mp_sgd_mom_update',
     'mp_sgd_update',
     'multi_all_finite',
+    'multi_lars',
     'multi_mp_sgd_mom_update',
     'multi_mp_sgd_update',
     'multi_sgd_mom_update',
     'multi_sgd_update',
+    'multi_sum_sq',
+    'nag_mom_update',
     'negative',
-    'normal',
     'one_hot',
     'ones_like',
-    'pad',
     'pick',
+    'preloaded_multi_mp_sgd_mom_update',
+    'preloaded_multi_mp_sgd_update',
+    'preloaded_multi_sgd_mom_update',
+    'preloaded_multi_sgd_update',
     'radians',
-    'random_exponential',
-    'random_gamma',
-    'random_generalized_negative_binomial',
-    'random_negative_binomial',
-    'random_normal',
-    'random_poisson',
-    'random_randint',
-    'random_uniform',
-    'ravel_multi_index',
     'relu',
     'repeat',
-    'reshape',
+    'reset_arrays',
     'reshape_like',
     'reverse',
     'rint',
     'rmsprop_update',
     'rmspropalex_update',
     'round',
-    'sample_exponential',
-    'sample_gamma',
-    'sample_generalized_negative_binomial',
-    'sample_multinomial',
-    'sample_negative_binomial',
-    'sample_normal',
-    'sample_poisson',
-    'sample_uniform',
     'scatter_nd',
     'sgd_mom_update',
     'sgd_update',
     'shape_array',
-    'shuffle',
     'sigmoid',
     'sign',
     'signsgd_update',
@@ -327,51 +438,55 @@ FP16_FP32_FUNCS = [
     'softsign',
     'sort',
     'space_to_depth',
-    'split',
     'sqrt',
     'squeeze',
-    'stop_gradient',
-    'swapaxes',
     'take',
     'tanh',
     'tile',
     'transpose',
     'trunc',
-    'uniform',
-    'unravel_index',
     'zeros_like',
-    '_sg_mkldnn_conv',
-    '_sg_mkldnn_fully_connected',
-    'CuDNNBatchNorm',
-    '_TensorRT',
     ]
 
+if Features().is_enabled('CUDNN'):
+    FP16_FP32_FUNCS.extend([
+        'CuDNNBatchNorm',
+    ])
+
 # Functions that have to be cast to FP32 due to possible
 # overflows
 FP32_FUNCS = [
-    'Convolution_v1',
     'IdentityAttachKLSparseReg',
     'arccos',
-    '_sparse_arccos',
     'arcsin',
     'cosh',
-    '_sparse_cosh',
     'erfinv',
     'sinh',
     'tan',
-    '_sparse_tan',
     'arctanh',
-    '_sparse_arcsin',
-    '_sparse_arctanh',
+    '_contrib_calibrate_entropy',
     '_contrib_MultiBoxDetection',
     '_contrib_MultiBoxPrior',
     '_contrib_MultiBoxTarget',
+    '_npi_arccos',
+    '_npi_arcsin',
+    '_npi_arctanh',
+    '_npi_cosh',
+    '_npi_sinh',
+    '_npi_tan',
 
     # Exponents
+    '_npi_exp',
+    '_npi_expm1',
+    '_npi_ldexp',
+    '_npi_ldexp_scalar',
+    '_npi_log',
+    '_npi_log10',
+    '_npi_log1p',
+    '_npi_log2',
+    '_npi_rldexp_scalar',
     'exp',
     'expm1',
-    '_sparse_exp',
-    '_sparse_expm1',
     'log',
     'log10',
     'log2',
@@ -380,30 +495,32 @@ FP32_FUNCS = [
     # Powers
     'broadcast_power',
     'square',
-    '_sparse_square',
     'reciprocal',
-    '_RDivScalar',
     '_rdiv_scalar',
     'rsqrt',
     'rcbrt',
-    '_Power',
-    '_PowerScalar',
     '_power',
     '_power_scalar',
-    '_RPowerScalar',
     '_rpower_scalar',
-    'linalg_sumlogdiag',
-    '_Hypot',
-    '_HypotScalar',
-    '_hypot',
-    '_hypot_scalar',
-    'broadcast_hypot',
     '_square_sum',
     '_contrib_hawkesll',
+    '_npi_power',
+    '_npi_power_scalar',
+    '_npi_reciprocal',
+    '_npi_rpower_scalar',
+    '_npi_square',
 
     # Reductions
+    '_npi_average',
+    '_npi_cumsum',
+    '_npi_mean',
+    '_npi_polyval',
+    '_npi_prod',
+    '_npi_std',
+    '_npi_sum',
+    '_npi_trace',
+    '_npi_var',
     'sum',
-    'sum_axis',
     'nansum',
     'prod',
     'nanprod',
@@ -414,11 +531,26 @@ FP32_FUNCS = [
     'moments',
 
     # Misc
+    '_npi_cholesky',
+    '_npi_eig',
+    '_npi_eigh',
+    '_npi_eigvals',
+    '_npi_eigvalsh',
+    '_npi_lstsq',
+    '_npi_matrix_rank',
+    '_npi_matrix_rank_none_tol',
+    '_npi_norm',
+    '_npi_pinv',
+    '_npi_pinv_scalar_rcond',
+    '_npi_qr',
+    '_npi_solve',
+    '_npi_svd',
+    '_npi_tensorinv',
+    '_npi_tensorsolve',
+    'digamma',
     'gamma',
     'gammaln',
     '_linalg_gelqf',
-    '_linalg_gemm',
-    '_linalg_gemm2',
     '_linalg_potrf',
     '_linalg_potri',
     '_linalg_sumlogdiag',
@@ -433,42 +565,16 @@ FP32_FUNCS = [
     '_linalg_inverse',
     '_linalg_det',
     '_linalg_slogdet',
-    'linalg_syrk',
-    'linalg_potrf',
-    'linalg_potri',
-    'linalg_gemm2',
-    'linalg_gemm',
-    'linalg_gelqf',
-    'linalg_trmm',
-    'linalg_trsm',
-    'linalg_makediag',
-    'linalg_extractdiag',
-    'linalg_maketrian',
-    'linalg_extracttrian',
-    'linalg_inverse',
-    'linalg_det',
-    'linalg_slogdet',
     '_NDArray',
     '_Native',
     '_contrib_count_sketch',
     '_contrib_SyncBatchNorm',
     '_contrib_fft',
-    '_sparse_gamma',
-    '_sparse_gammaln',
-    '_sparse_log',
-    '_sparse_log10',
-    '_sparse_log1p',
-    '_sparse_log2',
-    '_sparse_make_loss',
-    '_sparse_mean',
-    '_sparse_norm',
-    '_sparse_rsqrt',
     'argsort',
     'topk',
 
     # Neural network
     'softmax',
-    'Softmax',
     'log_softmax',
     'InstanceNorm',
     'LayerNorm',
@@ -482,13 +588,17 @@ FP32_FUNCS = [
     'make_loss',
     'Custom',
     'CTCLoss',
-    '_contrib_CTCLoss',
-    '_contrib_ctc_loss',
-    'ctc_loss',
     '_npx_deformable_convolution',
+    '_npx_modulated_deformable_convolution',
     '_contrib_DeformablePSROIPooling',
     ]
 
+if Features().is_enabled('MKLDNN'):
+    FP32_FUNCS.extend([
+        '_sg_mkldnn_conv',
+        '_sg_mkldnn_fully_connected',
+    ])
+
 # Functions that have to be cast to FP32 only for
 # some values of their parameters
 CONDITIONAL_FP32_FUNCS = [
@@ -499,51 +609,59 @@ CONDITIONAL_FP32_FUNCS = [
 # Functions with multiple inputs, that need the same
 # type of all their inputs
 WIDEST_TYPE_CASTS = [
-    '_Plus',
-    '_plus',
-    '_Minus',
-    '_sub',
-    '_Mul',
-    '_Div',
-    '_div',
-    '_Mod',
-    '_Not_Equal',
-    '_Equal',
     '_equal',
-    '_Greater',
     '_greater',
-    '_Greater_Equal',
     '_greater_equal',
-    '_Lesser',
-    '_Lesser_Equal',
+    '_hypot',
     '_lesser',
     '_lesser_equal',
-    '_Logical_And',
-    '_Logical_Or',
-    '_Logical_Xor',
     '_logical_and',
     '_logical_or',
     '_logical_xor',
     '_maximum',
     '_minimum',
-    '_minus',
     '_mod',
-    '_mul',
     '_not_equal',
+    '_npi_column_stack',
+    '_npi_concatenate',
+    '_npi_copysign',
+    '_npi_cross',
+    '_npi_dot',
+    '_npi_ediff1d',
+    '_npi_equal',
+    '_npi_fmax',
+    '_npi_fmin',
+    '_npi_fmod',
+    '_npi_greater',
+    '_npi_greater_equal',
+    '_npi_hypot',
+    '_npi_kron',
+    '_npi_less',
+    '_npi_less_equal',
+    '_npi_logical_and',
+    '_npi_logical_or',
+    '_npi_logical_xor',
+    '_npi_not_equal',
+    '_npi_dstack',
+    '_npi_hstack',
+    '_npi_stack',
+    '_npi_tensordot',
+    '_npi_tensordot_int_axes',
+    '_npi_vstack',
+    '_npi_where',
+    '_npx_index_add',
+    '_npx_index_update',
     'Concat',
-    'concat',
+    '_contrib_RROIAlign',
     'Correlation',
-    'ElementWiseSum',
-    '_sparse_ElementWiseSum',
     'add_n',
-    '_sparse_add_n',
     'batch_dot',
     'broadcast_add',
-    'broadcast_plus',
     'broadcast_div',
     'broadcast_equal',
     'broadcast_greater',
     'broadcast_greater_equal',
+    'broadcast_hypot',
     'broadcast_lesser',
     'broadcast_lesser_equal',
     'broadcast_logical_and',
@@ -551,7 +669,6 @@ WIDEST_TYPE_CASTS = [
     'broadcast_logical_xor',
     'broadcast_maximum',
     'broadcast_minimum',
-    'broadcast_minus',
     'broadcast_mod',
     'broadcast_mul',
     'broadcast_not_equal',
@@ -562,15 +679,14 @@ WIDEST_TYPE_CASTS = [
     'elemwise_mul',
     'elemwise_sub',
     'stack',
-    '_Maximum',
-    '_Minimum',
     '_contrib_MultiProposal',
     '_contrib_PSROIPooling',
     '_contrib_Proposal',
     '_contrib_ROIAlign',
+    '_contrib_box_decode',
+    '_contrib_box_encode',
     '_contrib_box_iou',
     '_contrib_box_nms',
-    '_contrib_box_non_maximum_suppression',
     '_contrib_dgl_adjacency',
     '_contrib_dgl_csr_neighbor_non_uniform_sample',
     '_contrib_dgl_csr_neighbor_uniform_sample',
@@ -582,28 +698,7 @@ WIDEST_TYPE_CASTS = [
     '_contrib_interleaved_matmul_selfatt_qk',
     '_contrib_interleaved_matmul_selfatt_valatt',
     'where',
-    '_sparse_where',
-    '_sparse_broadcast_add',
-    '_sparse_broadcast_div',
-    '_sparse_broadcast_minus',
-    '_sparse_broadcast_mul',
-    '_sparse_broadcast_plus',
-    '_sparse_broadcast_sub',
-    '_sparse_dot',
-    '_sparse_elemwise_add',
-    '_sparse_elemwise_div',
-    '_sparse_elemwise_mul',
-    '_sparse_elemwise_sub',
-    '_sparse_sum',
 
-    'random_pdf_gamma',
-    'random_pdf_exponential',
-    'random_pdf_uniform',
-    'random_pdf_negative_binomial',
-    'random_pdf_generalized_negative_binomial',
-    'random_pdf_dirichlet',
-    'random_pdf_normal',
-    'random_pdf_poisson',
     '_random_pdf_gamma',
     '_random_pdf_exponential',
     '_random_pdf_uniform',
diff --git a/python/mxnet/contrib/amp/loss_scaler.py b/python/mxnet/contrib/amp/loss_scaler.py
index 3a177ce..771408e 100644
--- a/python/mxnet/contrib/amp/loss_scaler.py
+++ b/python/mxnet/contrib/amp/loss_scaler.py
@@ -19,9 +19,9 @@
 """Dynamic loss scaler for AMP."""
 import logging
 
-from ...ndarray import multi_all_finite
-from ...ndarray import ndarray as nd
 from ... import autograd as ag
+from ... import ndarray
+from ...util import is_np_array
 
 class LossScaler(object):
     """Dynamic loss scaler for AMP.
@@ -44,15 +44,21 @@ class LossScaler(object):
 
     def has_overflow(self, params):
         """Check gradients for overflow."""
+        if is_np_array():
+            all_finite_f = ndarray.numpy._internal.multi_all_finite
+            ones_f = ndarray.numpy.ones
+        else:
+            all_finite_f = ndarray.multi_all_finite
+            ones_f = ndarray.ones
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
-            gpu_output = nd.ones((1,), ctx=valid_params[0].context)
+            gpu_output = ones_f((1,), ctx=valid_params[0].context)
             nb_params = len(valid_params)
             for idx in range(0, nb_params, chunk_size):
-                multi_all_finite(*valid_params[idx:idx+chunk_size],
-                                 num_arrays=len(valid_params[idx:idx+chunk_size]),
-                                 init_output=False, out=gpu_output)
+                all_finite_f(*valid_params[idx:idx+chunk_size],
+                             num_arrays=len(valid_params[idx:idx+chunk_size]),
+                             init_output=False, out=gpu_output)
         has_overflow = not bool(gpu_output.asnumpy())
         self._loss_scale = self._next_loss_scale
         if has_overflow:
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 04efbee..b34fe0d 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -1143,6 +1143,32 @@ def get_all_registered_operators():
     return mx_registered_operator_names
 
 
+def get_all_registered_operators_grouped():
+    """Get all registered MXNet operator names, grouped by 'original' operator.
+
+    Returns
+    -------
+    names : a dictionary, mapping op name to the list of all its aliases (including the original).
+    """
+    ret = {}
+    for aname in get_all_registered_operators():
+        op_handle = OpHandle()
+        check_call(_LIB.NNGetOpHandle(c_str(aname), ctypes.byref(op_handle)))
+        name = ctypes.c_char_p()
+        desc = ctypes.c_char_p()
+        num_args = mx_uint()
+        arg_names = ctypes.POINTER(ctypes.c_char_p)()
+        arg_types = ctypes.POINTER(ctypes.c_char_p)()
+        arg_descs = ctypes.POINTER(ctypes.c_char_p)()
+        ret_types = ctypes.POINTER(ctypes.c_char_p)()
+        check_call(_LIB.NNGetOpInfo(op_handle, ctypes.byref(name), ctypes.byref(desc),
+                                    ctypes.byref(num_args), ctypes.byref(arg_names),
+                                    ctypes.byref(arg_types), ctypes.byref(arg_descs),
+                                    ctypes.byref(ret_types)))
+        ret.setdefault(py_str(name.value), []).append(aname)
+    return ret
+
+
 OperatorArguments = collections.namedtuple('OperatorArguments', ['narg', 'names', 'types'])
 
 
diff --git a/src/operator/contrib/all_finite.cc b/src/operator/contrib/all_finite.cc
index 5e77510..8ef3671 100755
--- a/src/operator/contrib/all_finite.cc
+++ b/src/operator/contrib/all_finite.cc
@@ -97,6 +97,7 @@ inline void MultiAllFiniteCPU(const nnvm::NodeAttrs& attrs,
 DMLC_REGISTER_PARAMETER(AllFiniteParam);
 
 NNVM_REGISTER_OP(all_finite)
+.add_alias("_npi_all_finite")
 .describe(R"code(Check if all the float numbers in the array are finite (used for AMP)
 )code" ADD_FILELINE)
 .set_num_inputs(1)
@@ -129,6 +130,7 @@ NNVM_REGISTER_OP(all_finite)
 DMLC_REGISTER_PARAMETER(MultiAllFiniteParam);
 
 NNVM_REGISTER_OP(multi_all_finite)
+.add_alias("_npi_multi_all_finite")
 .describe(R"code(Check if all the float numbers in all the arrays are finite (used for AMP)
 )code" ADD_FILELINE)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index 7690783..0acb699 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -115,6 +115,7 @@ inline static bool AMPMultiCastStorageType(const nnvm::NodeAttrs& attrs, const i
 #endif  // MXNET_USE_MKLDNN == 1
 
 NNVM_REGISTER_OP(amp_cast)
+.add_alias("_npi_amp_cast")
 .describe(R"code(Cast function between low precision float/FP32 used by AMP.
 
 It casts only between low precision float/FP32 and does not do anything for other types.
@@ -158,6 +159,7 @@ NNVM_REGISTER_OP(_backward_amp_cast)
 .set_attr<FCompute>("FCompute<cpu>", AMPCastCompute<cpu>);
 
 NNVM_REGISTER_OP(amp_multicast)
+.add_alias("_npi_amp_multicast")
 .describe(R"code(Cast function used by AMP, that casts its inputs to the common widest type.
 
 It casts only between low precision float/FP32 and does not do anything for other types.
diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py
index d7a6e80..6895723 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -29,6 +29,7 @@ from mxnet.test_utils import set_default_context, same_symbol_structure
 from mxnet.gluon.model_zoo.vision import get_model
 from mxnet.gluon import SymbolBlock, nn, rnn
 from mxnet.contrib.amp import amp
+from mxnet.operator import get_all_registered_operators_grouped
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import with_seed, teardown_module, assert_raises_cudnn_not_satisfied
@@ -42,7 +43,6 @@ def amp_tests(request):
 
     request.addfinalizer(teardown)
 
-@pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
 def test_amp_coverage(amp_tests):
     conditional = [item[0] for item in amp.lists.symbol_fp16.CONDITIONAL_FP32_FUNCS]
 
@@ -66,42 +66,36 @@ def test_amp_coverage(amp_tests):
     assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
 
     # Check the coverage
-    py_str = lambda x: x.decode('utf-8')
-
-    plist = ctypes.POINTER(ctypes.c_char_p)()
-    size = ctypes.c_uint()
-
-    mx.base._LIB.MXListAllOpNames(ctypes.byref(size),
-                                     ctypes.byref(plist))
-    op_names = []
-    for i in range(size.value):
-        s = py_str(plist[i])
-        if not s.startswith("_backward") \
-           and not s.startswith("_contrib_backward_"):
-            op_names.append(s)
-
-    ret1 = set(op_names) - set(t)
-
-    if ret1 != set():
-        warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in "
-                       "python/mxnet/contrib/amp/lists/symbol_fp16.py) - please add them. "
-                       """Please follow these guidelines for choosing a proper list:
-                       - if your operator is not to be used in a computational graph
-                         (e.g. image manipulation operators, optimizers) or does not have
-                         inputs, put it in FP16_FP32_FUNCS list,
-                       - if your operator requires FP32 inputs or is not safe to use with lower
-                         precision, put it in FP32_FUNCS list,
-                       - if your operator supports both FP32 and lower precision, has
-                         multiple inputs and expects all inputs to be of the same
-                         type, put it in WIDEST_TYPE_CASTS list,
-                       - if your operator supports both FP32 and lower precision and has
-                         either a single input or supports inputs of different type,
-                         put it in FP16_FP32_FUNCS list,
-                       - if your operator is both safe to use in lower precision and
-                         it is highly beneficial to use it in lower precision, then
-                         put it in FP16_FUNCS (this is unlikely for new operators)
-                       - If you are not sure which list to choose, FP32_FUNCS is the
-                         safest option""")
+    covered = set(t)
+    ops = get_all_registered_operators_grouped()
+    required = set(k for k in ops
+                   if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
+                   not k.endswith("_backward"))
+
+    extra = covered - required
+    assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"
+
+    guidelines = """Please follow these guidelines for choosing a proper list:
+    - if your operator is not to be used in a computational graph
+      (e.g. image manipulation operators, optimizers) or does not have
+      inputs, put it in FP16_FP32_FUNCS list,
+    - if your operator requires FP32 inputs or is not safe to use with lower
+      precision, put it in FP32_FUNCS list,
+    - if your operator supports both FP32 and lower precision, has
+      multiple inputs and expects all inputs to be of the same
+      type, put it in WIDEST_TYPE_CASTS list,
+    - if your operator supports both FP32 and lower precision and has
+      either a single input or supports inputs of different type,
+      put it in FP16_FP32_FUNCS list,
+    - if your operator is both safe to use in lower precision and
+      it is highly beneficial to use it in lower precision, then
+      put it in FP16_FUNCS (this is unlikely for new operators)
+    - If you are not sure which list to choose, FP32_FUNCS is the
+                     safest option"""
+    diff = required - covered
+    assert not diff, f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in " \
+        f"python/mxnet/contrib/amp/lists/symbol_fp16.py) - please add them. " \
+        f"\n{guidelines}"
 
 @with_seed()
 @pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
@@ -120,7 +114,6 @@ def test_amp_conversion_rnn(amp_tests):
 
 
 @with_seed()
-@pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
 def test_fp16_casting(amp_tests):
     data = mx.sym.var("data")
     out1 = mx.sym.amp_cast(data, dtype="float16")

Reply via email to