This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 39ea683 AMP support for Numpy ops (#19036)
39ea683 is described below
commit 39ea6832a004c873a89b2bd59d71488bac811e59
Author: mk-61 <[email protected]>
AuthorDate: Fri Oct 2 15:37:03 2020 -0700
AMP support for Numpy ops (#19036)
* AMP support for numpy ops
* Move misplaced docstring
* Fix numpy submodule handling in AMP
* Fix module selection when AMP-patching numpy ops
* Check if a model exists in AMP init
* Make AMP loss scale public
* Re-enable a test
* Re-disable a test, with the original reason
Co-authored-by: Vladimir Cherepanov <[email protected]>
---
python/mxnet/contrib/amp/amp.py | 157 ++++++---
python/mxnet/contrib/amp/lists/symbol_fp16.py | 489 +++++++++++++++-----------
python/mxnet/contrib/amp/loss_scaler.py | 18 +-
python/mxnet/operator.py | 26 ++
src/operator/contrib/all_finite.cc | 2 +
src/operator/tensor/amp_cast.cc | 2 +
tests/python/gpu/test_contrib_amp.py | 69 ++--
7 files changed, 471 insertions(+), 292 deletions(-)
diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py
index 86bf513..5fde733 100644
--- a/python/mxnet/contrib/amp/amp.py
+++ b/python/mxnet/contrib/amp/amp.py
@@ -25,10 +25,13 @@ __all__ = ['init', 'init_trainer', 'scale_loss', 'unscale',
'convert_model',
from array import array
import ctypes
+import inspect
import logging
import contextlib
+import sys
import numpy as np
+from mxnet import numpy
from ... import symbol
from ...context import gpu
from ...symbol import Symbol
@@ -36,33 +39,35 @@ from ...symbol import contrib as symbol_contrib
from ... import ndarray
from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from . import lists
-from ...gluon import trainer
+from ...gluon import Block, trainer
from ... import base
-from ...base import c_str_array, SymbolHandle, check_call, _LIB, mx_uint,
c_array_buf
+from ...base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
+ _NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
+ c_str_array, SymbolHandle, check_call, _LIB, mx_uint,
c_array_buf)
from ... import optimizer as opt
from .loss_scaler import LossScaler
+from ...operator import get_all_registered_operators_grouped
bfloat16 = np.dtype([('bfloat16', np.uint16)])
-def _cast_symbol_NDArray(s, dtype):
- float_types_gpu = (np.float16, np.float32)
- float_types_cpu = (bfloat16, np.float32)
- if isinstance(s, Symbol):
- return symbol.amp_cast(s, dtype=dtype)
- elif isinstance(s, NDArray):
- if (s.dtype != dtype and s.dtype in float_types_gpu and
s.context.device_type != 'cpu'):
- return ndarray.amp_cast(s, dtype=dtype)
- elif (s.dtype != dtype and s.dtype in float_types_cpu and
s.context.device_type == 'cpu'):
- return ndarray.amp_cast(s, dtype=dtype)
- else:
- return s
- else:
- return s
+float_types_gpu = (np.float16, np.float32)
+float_types_cpu = (bfloat16, np.float32)
-def _get_fun_to_wrap(name, module, submodule_dict):
+def _cast_symbol_NDArray(s, dtype, is_numpy_module=False):
+ if isinstance(s, Symbol):
+ amp_cast = symbol.numpy._internal.amp_cast if is_numpy_module else
symbol.amp_cast
+ return amp_cast(s, dtype=dtype)
+ if isinstance(s, NDArray):
+ amp_cast = ndarray.numpy._internal.amp_cast if is_numpy_module else
ndarray.amp_cast
+ if s.dtype != dtype and (s.dtype in float_types_gpu and
s.context.device_type != 'cpu' or
+ s.dtype in float_types_cpu and
s.context.device_type == 'cpu'):
+ return amp_cast(s, dtype=dtype)
+ return s
+
+def _get_nd_fun_to_wrap(name, module, submodule_dict):
module_internal = getattr(module, "_internal")
prefix = base._get_op_name_prefix(name)
- if len(prefix) > 0:
+ if prefix:
if prefix != '_random_' or name.endswith('_like'):
func_name = name[len(prefix):]
cur_module = submodule_dict[prefix]
@@ -77,8 +82,26 @@ def _get_fun_to_wrap(name, module, submodule_dict):
cur_module = module
return func_name, cur_module
-def _wrap_symbol_functions(module, target_dtype, target_precision_ops=None,
- conditional_fp32_ops=None, fp32_ops=None):
+def _get_np_fun_to_wrap(name, ns_prefix):
+ for pre, mod, subs in ((_NP_OP_PREFIX, 'numpy', _NP_OP_SUBMODULE_LIST),
+ (_NP_EXT_OP_PREFIX, 'numpy_extension',
_NP_EXT_OP_SUBMODULE_LIST),
+ (_NP_INTERNAL_OP_PREFIX, 'numpy._internal', [])):
+ if name.startswith(pre):
+ name = name[len(pre):]
+ for sub in subs:
+ if name.startswith(sub):
+ return name[len(sub):],
sys.modules[f'{ns_prefix}.{mod}.{sub[1:-1]}']
+ return name, sys.modules[f'{ns_prefix}.{mod}']
+ assert False
+ return None # for pylint
+
+def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases,
get_cond_aliases,
+ get_fun_to_wrap, target_precision_ops=None,
conditional_fp32_ops=None,
+ fp32_ops=None):
+
+ nd_mod = ndarray.numpy._internal if is_numpy_module else ndarray
+ sy_mod = symbol.numpy._internal if is_numpy_module else symbol
+
def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
def _new_fun(*args, **kwargs):
if cond_arg is not None:
@@ -91,9 +114,10 @@ def _wrap_symbol_functions(module, target_dtype,
target_precision_ops=None,
if fp32_param[i]:
new_args.append(x)
else:
- new_args.append(_cast_symbol_NDArray(x, target_dtype))
+ new_args.append(_cast_symbol_NDArray(x, target_dtype,
is_numpy_module))
else:
- new_args = list(map(lambda x: _cast_symbol_NDArray(x,
target_dtype), args))
+ new_args = list(map(
+ lambda x: _cast_symbol_NDArray(x, target_dtype,
is_numpy_module), args))
args = tuple(new_args)
if fp32_param:
new_kwargs = {}
@@ -101,10 +125,11 @@ def _wrap_symbol_functions(module, target_dtype,
target_precision_ops=None,
if k in fp32_param:
new_kwargs[k] = v
else:
- new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype)
+ new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype,
is_numpy_module)
kwargs = new_kwargs
else:
- kwargs = {k: _cast_symbol_NDArray(v, target_dtype) for k, v in
kwargs.items()}
+ kwargs = {k: _cast_symbol_NDArray(v, target_dtype,
is_numpy_module)
+ for k, v in kwargs.items()}
return f(*args, **kwargs)
_new_fun.__name__ = f.__name__
_new_fun.__module__ = f.__module__
@@ -126,10 +151,10 @@ def _wrap_symbol_functions(module, target_dtype,
target_precision_ops=None,
if (x.name in aux) or fp32_param[i]:
new_inputs.append(x)
else:
- new_inputs.append(_cast_symbol_NDArray(x,
target_dtype))
+ new_inputs.append(_cast_symbol_NDArray(x,
target_dtype, is_numpy_module))
inputs = new_inputs
else:
- inputs = list(map(lambda x: _cast_symbol_NDArray(x,
target_dtype)
+ inputs = list(map(lambda x: _cast_symbol_NDArray(x,
target_dtype, is_numpy_module)
if x.name not in aux else x, inputs))
atomic_sym = sym._gen_atomic_symbol()
wrapped_sym = atomic_sym(*inputs)
@@ -162,11 +187,11 @@ def _wrap_symbol_functions(module, target_dtype,
target_precision_ops=None,
widest_type = np.float32
for arr, index, arg in symbols:
if arg.dtype != widest_type and arg.dtype == target_dtype:
- arr[index] = ndarray.amp_cast(arg, dtype=widest_type)
+ arr[index] = nd_mod.amp_cast(arg, dtype=widest_type)
else:
# Symbol case
sym_to_check = list(map(lambda x: x[2], symbols))
- casted_syms = symbol.amp_multicast(*sym_to_check,
num_outputs=len(sym_to_check))
+ casted_syms = sy_mod.amp_multicast(*sym_to_check,
num_outputs=len(sym_to_check))
symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
zip(symbols, casted_syms)))
for arr, index, arg in symbols:
@@ -180,54 +205,50 @@ def _wrap_symbol_functions(module, target_dtype,
target_precision_ops=None,
_wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib)
else _ndarray_wrapper
- submodule_dict = {}
- for op_name_prefix in base._OP_NAME_PREFIX_LIST:
- submodule_dict[op_name_prefix] =\
- getattr(module, op_name_prefix[1:-1])
fp32_param_list = list_lp16_use_fp32_params(target_dtype)
wrap_list = target_precision_ops if target_precision_ops is not None \
else list_lp16_ops(target_dtype)
- for fun_name in wrap_list:
+ for fun_name in get_aliases(wrap_list):
try:
- fun_name, cur_module = _get_fun_to_wrap(fun_name, module,
submodule_dict)
+ fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
fp32_param = fp32_param_list[fun_name] if (fp32_param_list and
fun_name in fp32_param_list) else None
setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype,
fp32_param=fp32_param))
- if cur_module == module:
+ if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype,
fp32_param=fp32_param))
except AttributeError:
raise
wrap_list = fp32_ops if fp32_ops is not None else
list_fp32_ops(target_dtype)
- for fun_name in wrap_list:
+ for fun_name in get_aliases(wrap_list):
try:
- fun_name, cur_module = _get_fun_to_wrap(fun_name, module,
submodule_dict)
+ fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
- if cur_module == module:
+ if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
except AttributeError:
raise
wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
else list_conditional_fp32_ops(target_dtype)
- for fun_name, arg, arg_values in wrap_list:
+ for fun_name, arg, arg_values in get_cond_aliases(wrap_list):
try:
- fun_name, cur_module = _get_fun_to_wrap(fun_name, module,
submodule_dict)
+ fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32,
cond_arg=(arg, arg_values)))
- if cur_module == module:
+ if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32,
cond_arg=(arg, arg_values)))
except AttributeError:
raise
- for fun_name in list_widest_type_cast(target_dtype):
+ for fun_name in get_aliases(list_widest_type_cast(target_dtype)):
try:
- fun_name, cur_module = _get_fun_to_wrap(fun_name, module,
submodule_dict)
+ fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
- if cur_module == module:
+ if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
except AttributeError:
raise
@@ -278,6 +299,14 @@ def scale_loss(loss, optimizer_or_trainer):
else:
yield optimizer_or_trainer._amp_loss_scaler.loss_scale * loss
+def warn_if_model_exists():
+ for f in inspect.stack():
+ for k, v in f.frame.f_locals.items():
+ if isinstance(v, Block):
+ logging.warning('Block %s created in [%s:%d] before AMP init.',
+ k, f.filename, f.lineno)
+ return
+
def init(target_dtype='float16', target_precision_ops=None,
conditional_fp32_ops=None, fp32_ops=None):
"""Initialize AMP (automatic mixed precision).
@@ -310,13 +339,39 @@ def init(target_dtype='float16',
target_precision_ops=None,
target_dtype = bfloat16
else:
target_dtype = np.dtype(target_dtype)
- _wrap_symbol_functions(symbol, target_dtype, target_precision_ops,
- conditional_fp32_ops, fp32_ops)
- _wrap_symbol_functions(ndarray, target_dtype, target_precision_ops,
- conditional_fp32_ops, fp32_ops)
+
+ warn_if_model_exists()
+
+ ops = get_all_registered_operators_grouped()
+ get_aliases_nd = lambda l: [a for op in l for a in ops[op] if not
base._is_np_op(a)]
+ get_aliases_np = lambda l: [a for op in l for a in ops[op] if
base._is_np_op(a)]
+ get_aliases_np_pub = lambda l: [a for op in l for a in ops[op]
+ if a.startswith(('_np_', '_npx_'))]
+ get_cond_aliases_nd = lambda l: [(a, *rest) for op, *rest in l for a
in ops[op]
+ if not base._is_np_op(a)]
+ get_cond_aliases_np = lambda l: [(a, *rest) for op, *rest in l for a
in ops[op]
+ if base._is_np_op(a)]
+ get_cond_aliases_np_pub = lambda l: [(a, *rest) for op, *rest in l for
a in ops[op]
+ if a.startswith(('_np_',
'_npx_'))]
+ sy_submodules = {p:getattr(symbol, p[1:-1]) for p in
base._OP_NAME_PREFIX_LIST}
+ get_sy_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod,
sy_submodules)
+ nd_submodules = {p:getattr(ndarray, p[1:-1]) for p in
base._OP_NAME_PREFIX_LIST}
+ get_nd_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod,
nd_submodules)
+ get_np_sy_fun = lambda fun, mod: _get_np_fun_to_wrap(fun,
"mxnet.symbol")
+ get_np_nd_fun = lambda fun, mod: _get_np_fun_to_wrap(fun,
"mxnet.ndarray")
+ get_np_fun = lambda fun, mode: _get_np_fun_to_wrap(fun, "mxnet")
+ todo = [
+ (symbol, False, get_aliases_nd, get_cond_aliases_nd, get_sy_fun),
+ (ndarray, False, get_aliases_nd, get_cond_aliases_nd, get_nd_fun),
+ (symbol.numpy, True, get_aliases_np, get_cond_aliases_np,
get_np_sy_fun),
+ (ndarray.numpy, True, get_aliases_np, get_cond_aliases_np,
get_np_nd_fun),
+ (numpy, True, get_aliases_np_pub, get_cond_aliases_np_pub,
get_np_fun),
+ ]
_loss_scaler = LossScaler()
- _wrap_loss_output_functions(ndarray, _loss_scaler, target_dtype)
- _wrap_loss_output_functions(symbol, _loss_scaler, target_dtype)
+ for module, is_numpy, get_aliases, get_cond_aliases, get_fun in todo:
+ _wrap_module_functions(module, is_numpy, target_dtype,
get_aliases, get_cond_aliases,
+ get_fun, target_precision_ops,
conditional_fp32_ops, fp32_ops)
+ _wrap_loss_output_functions(module, _loss_scaler, target_dtype)
def init_trainer(optimizer_or_trainer):
"""Initialize trainer or optimizer to work with AMP dynamic loss scaling.
@@ -339,8 +394,8 @@ def init_trainer(optimizer_or_trainer):
if isinstance(optimizer_or_trainer, trainer.Trainer):
optimizer_or_trainer._amp_loss_scaler = loss_scaler
optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
+ trainer.Trainer.amp_loss_scale = property(lambda self:
self._amp_loss_scaler.loss_scale)
elif isinstance(optimizer_or_trainer, opt.Optimizer):
- # TODO(ptredak): make it work with the optimizer
raise TypeError("AMP is currently only compatible with Gluon Trainer")
else:
raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
diff --git a/python/mxnet/contrib/amp/lists/symbol_fp16.py
b/python/mxnet/contrib/amp/lists/symbol_fp16.py
index 275d088..db608e4 100644
--- a/python/mxnet/contrib/amp/lists/symbol_fp16.py
+++ b/python/mxnet/contrib/amp/lists/symbol_fp16.py
@@ -18,8 +18,15 @@
# coding: utf-8
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in
symbol API."""
+from ....runtime import Features
+
+
# Functions that should be cast to lower precision
FP16_FUNCS = [
+ '_linalg_gemm',
+ '_linalg_gemm2',
+ '_npi_einsum',
+ '_npi_matmul',
'Convolution',
'Deconvolution',
'FullyConnected',
@@ -35,18 +42,29 @@ FP16_FP32_FUNCS = [
'BilinearSampler',
'BlockGrad',
'Cast',
- 'cast',
'cast_storage',
+ '_contrib_BatchNormWithReLU',
+ '_contrib_allclose',
+ '_contrib_arange_like',
+ '_contrib_dynamic_reshape',
+ '_contrib_intgemm_fully_connected',
+ '_contrib_intgemm_maxabsolute',
+ '_contrib_intgemm_prepare_data',
+ '_contrib_intgemm_prepare_weight',
+ '_contrib_intgemm_take_weight',
+ '_contrib_quantized_batch_norm',
+ '_contrib_quantized_elemwise_mul',
+ '_contrib_quantized_embedding',
+ '_contrib_mrcnn_mask_target',
+ '_contrib_round_ste',
+ '_contrib_sign_ste',
'Crop',
'Dropout',
'Embedding',
- '_sparse_Embedding',
- '_sparse_FullyConnected',
'Flatten',
'GridGenerator',
'Pad',
'Pooling',
- 'Pooling_v1',
'ROIPooling',
'Reshape',
'SequenceLast',
@@ -57,31 +75,15 @@ FP16_FP32_FUNCS = [
'SwapAxis',
'UpSampling',
'_CachedOp',
+ '_CachedOpThreadSafe',
'_CrossDeviceCopy',
'_CustomFunction',
- '_DivScalar',
- '_EqualScalar',
- '_GreaterScalar',
- '_GreaterEqualScalar',
- '_LesserScalar',
- '_LesserEqualScalar',
- '_LogicalAndScalar',
- '_LogicalOrScalar',
- '_LogicalXorScalar',
- '_MaximumScalar',
- '_MinimumScalar',
- '_MinusScalar',
- '_ModScalar',
- '_MulScalar',
+ '_FusedOp',
+ '_FusedOpHelper',
+ '_FusedOpOutHelper',
'_NoGradient',
- '_NotEqualScalar',
- '_PlusScalar',
- '_RMinusScalar',
- '_RModScalar',
'_adamw_update',
- '_add',
'_arange',
- '_broadcast_backward',
'_cond',
'_contrib_AdaptiveAvgPooling2D',
'_contrib_BilinearResize2D',
@@ -109,8 +111,6 @@ FP16_FP32_FUNCS = [
'_contrib_requantize',
'_copy',
'_copyto',
- '_crop_assign',
- '_crop_assign_scalar',
'_cvcopyMakeBorder',
'_cvimdecode',
'_cvimread',
@@ -125,6 +125,7 @@ FP16_FP32_FUNCS = [
'_greater_scalar',
'_greater_equal_scalar',
'_histogram',
+ '_hypot_scalar',
'_identity_with_attr_like_rhs',
'_image_adjust_lighting',
'_image_flip_left_right',
@@ -133,6 +134,8 @@ FP16_FP32_FUNCS = [
'_image_random_brightness',
'_image_random_color_jitter',
'_image_random_contrast',
+ '_image_random_crop',
+ '_image_random_resized_crop',
'_image_random_flip_left_right',
'_image_random_flip_top_bottom',
'_image_random_hue',
@@ -152,7 +155,162 @@ FP16_FP32_FUNCS = [
'_mod_scalar',
'_mp_adamw_update',
'_mul_scalar',
+ '_multi_adamw_update',
+ '_multi_lamb_update',
+ '_multi_lans_update',
+ '_multi_mp_adamw_update',
+ '_multi_mp_lamb_update',
+ '_multi_mp_lans_update',
'_not_equal_scalar',
+ '_np_reshape',
+ '_npi_absolute',
+ '_npi_add',
+ '_npi_add_scalar',
+ '_npi_advanced_indexing',
+ '_npi_advanced_indexing_multiple',
+ '_npi_all',
+ '_npi_any',
+ '_npi_arange',
+ '_npi_arccosh',
+ '_npi_arcsinh',
+ '_npi_arctan',
+ '_npi_arctan2',
+ '_npi_arctan2_scalar',
+ '_npi_argmax',
+ '_npi_argmin',
+ '_npi_around',
+ '_npi_atleast_1d',
+ '_npi_atleast_2d',
+ '_npi_atleast_3d',
+ '_npi_bernoulli',
+ '_npi_bincount',
+ '_npi_bitwise_and',
+ '_npi_bitwise_and_scalar',
+ '_npi_bitwise_not',
+ '_npi_bitwise_or',
+ '_npi_bitwise_or_scalar',
+ '_npi_bitwise_xor',
+ '_npi_bitwise_xor_scalar',
+ '_npi_blackman',
+ '_npi_boolean_mask_assign_scalar',
+ '_npi_boolean_mask_assign_tensor',
+ '_npi_broadcast_to',
+ '_npi_cbrt',
+ '_npi_ceil',
+ '_npi_choice',
+ '_npi_copy',
+ '_npi_copysign_scalar',
+ '_npi_cos',
+ '_npi_degrees',
+ '_npi_delete',
+ '_npi_diag',
+ '_npi_diag_indices_from',
+ '_npi_diagflat',
+ '_npi_diagonal',
+ '_npi_diff',
+ '_npi_dsplit',
+ '_npi_equal_scalar',
+ '_npi_exponential',
+ '_npi_eye',
+ '_npi_fill_diagonal',
+ '_npi_fix',
+ '_npi_flip',
+ '_npi_floor',
+ '_npi_fmax_scalar',
+ '_npi_fmin_scalar',
+ '_npi_fmod_scalar',
+ '_npi_full',
+ '_npi_full_like',
+ '_npi_gamma',
+ '_npi_greater_equal_scalar',
+ '_npi_greater_scalar',
+ '_npi_gumbel',
+ '_npi_hamming',
+ '_npi_hanning',
+ '_npi_hsplit',
+ '_npi_identity',
+ '_npi_indices',
+ '_npi_insert_scalar',
+ '_npi_insert_slice',
+ '_npi_insert_tensor',
+ '_npi_interp',
+ '_npi_isinf',
+ '_npi_isfinite',
+ '_npi_isnan',
+ '_npi_isneginf',
+ '_npi_isposinf',
+ '_npi_laplace',
+ '_npi_less_equal_scalar',
+ '_npi_less_scalar',
+ '_npi_logistic',
+ '_npi_lcm',
+ '_npi_lcm_scalar',
+ '_npi_linspace',
+ '_npi_logical_not',
+ '_npi_logical_and_scalar',
+ '_npi_logical_or_scalar',
+ '_npi_logical_xor_scalar',
+ '_npi_logspace',
+ '_npi_max',
+ '_npi_min',
+ '_npi_mod',
+ '_npi_mod_scalar',
+ '_npi_moveaxis',
+ '_npi_multinomial',
+ '_npi_multiply',
+ '_npi_multiply_scalar',
+ '_npi_nan_to_num',
+ '_npi_negative',
+ '_npi_normal',
+ '_npi_normal_n',
+ '_npi_not_equal_scalar',
+ '_npi_ones',
+ '_npi_pad',
+ '_npi_pareto',
+ '_npi_percentile',
+ '_npi_powerd',
+ '_npi_radians',
+ '_npi_rarctan2_scalar',
+ '_npi_rayleigh',
+ '_npi_rcopysign_scalar',
+ '_npi_repeats',
+ '_npi_rfmod_scalar',
+ '_npi_rint',
+ '_npi_rmod_scalar',
+ '_npi_roll',
+ '_npi_rollaxis',
+ '_npi_rot90',
+ '_npi_rsubtract_scalar',
+ '_npi_rtrue_divide_scalar',
+ '_npi_share_memory',
+ '_npi_sign',
+ '_npi_sin',
+ '_npi_sqrt',
+ '_npi_squeeze',
+ '_npi_subtract',
+ '_npi_subtract_scalar',
+ '_npi_tanh',
+ '_npi_transpose',
+ '_npi_tri',
+ '_npi_tril',
+ '_npi_tril_indices',
+ '_npi_triu',
+ '_npi_true_divide',
+ '_npi_true_divide_scalar',
+ '_npi_trunc',
+ '_npi_uniform',
+ '_npi_uniform_n',
+ '_npi_unique',
+ '_npi_weibull',
+ '_npi_where_lscalar',
+ '_npi_where_rscalar',
+ '_npi_where_scalar2',
+ '_npi_zeros',
+ '_npx_constraint_check',
+ '_npx_nonzero',
+ '_npx_relu',
+ '_npx_reshape',
+ '_npx_sigmoid',
'_onehot_encode',
'_ones',
'_plus_scalar',
@@ -189,42 +347,9 @@ FP16_FP32_FUNCS = [
'_shuffle',
'_slice_assign',
'_slice_assign_scalar',
- '_sparse_abs',
'_sparse_adagrad_update',
- '_sparse_adam_update',
- '_sparse_arccosh',
- '_sparse_arcsinh',
- '_sparse_arctan',
- '_sparse_cast_storage',
- '_sparse_cbrt',
- '_sparse_ceil',
- '_sparse_clip',
- '_sparse_concat',
- '_sparse_cos',
- '_sparse_degrees',
- '_sparse_fix',
- '_sparse_floor',
- '_sparse_ftrl_update',
- '_sparse_negative',
- '_sparse_radians',
- '_sparse_relu',
'_sparse_retain',
- '_sparse_rint',
- '_sparse_round',
- '_sparse_sgd_mom_update',
- '_sparse_sgd_update',
- '_sparse_sigmoid',
- '_sparse_sign',
- '_sparse_sin',
- '_sparse_sinh',
- '_sparse_slice',
- '_sparse_sqrt',
- '_sparse_stop_gradient',
- '_sparse_tanh',
- '_sparse_trunc',
- '_sparse_zeros_like',
'_split_v2',
- '_split_v2_backward',
'_unravel_index',
'_zeros',
'_zeros_without_dtype',
@@ -240,16 +365,14 @@ FP16_FP32_FUNCS = [
'argmax_channel',
'argmin',
'batch_take',
- 'broadcast_axes',
'broadcast_axis',
'broadcast_like',
'broadcast_to',
'cbrt',
'ceil',
- 'choose_element_0index',
'clip',
+ 'col2im',
'cos',
- 'crop',
'degrees',
'depth_to_space',
'diag',
@@ -257,64 +380,52 @@ FP16_FP32_FUNCS = [
'expand_dims',
'fill_element_0index',
'fix',
- 'flatten',
- 'flip',
'floor',
'ftml_update',
'ftrl_update',
'gather_nd',
'hard_sigmoid',
- 'identity',
+ 'im2col',
+ 'lamb_update_phase1',
+ 'lamb_update_phase2',
'logical_not',
- 'max_axis',
'max',
'min',
- 'min_axis',
+ 'mp_lamb_update_phase1',
+ 'mp_lamb_update_phase2',
+ 'mp_nag_mom_update',
'mp_sgd_mom_update',
'mp_sgd_update',
'multi_all_finite',
+ 'multi_lars',
'multi_mp_sgd_mom_update',
'multi_mp_sgd_update',
'multi_sgd_mom_update',
'multi_sgd_update',
+ 'multi_sum_sq',
+ 'nag_mom_update',
'negative',
- 'normal',
'one_hot',
'ones_like',
- 'pad',
'pick',
+ 'preloaded_multi_mp_sgd_mom_update',
+ 'preloaded_multi_mp_sgd_update',
+ 'preloaded_multi_sgd_mom_update',
+ 'preloaded_multi_sgd_update',
'radians',
- 'random_exponential',
- 'random_gamma',
- 'random_generalized_negative_binomial',
- 'random_negative_binomial',
- 'random_normal',
- 'random_poisson',
- 'random_randint',
- 'random_uniform',
- 'ravel_multi_index',
'relu',
'repeat',
- 'reshape',
+ 'reset_arrays',
'reshape_like',
'reverse',
'rint',
'rmsprop_update',
'rmspropalex_update',
'round',
- 'sample_exponential',
- 'sample_gamma',
- 'sample_generalized_negative_binomial',
- 'sample_multinomial',
- 'sample_negative_binomial',
- 'sample_normal',
- 'sample_poisson',
- 'sample_uniform',
'scatter_nd',
'sgd_mom_update',
'sgd_update',
'shape_array',
- 'shuffle',
'sigmoid',
'sign',
'signsgd_update',
@@ -327,51 +438,55 @@ FP16_FP32_FUNCS = [
'softsign',
'sort',
'space_to_depth',
- 'split',
'sqrt',
'squeeze',
- 'stop_gradient',
- 'swapaxes',
'take',
'tanh',
'tile',
'transpose',
'trunc',
- 'uniform',
- 'unravel_index',
'zeros_like',
- '_sg_mkldnn_conv',
- '_sg_mkldnn_fully_connected',
- 'CuDNNBatchNorm',
- '_TensorRT',
]
+if Features().is_enabled('CUDNN'):
+ FP16_FP32_FUNCS.extend([
+ 'CuDNNBatchNorm',
+ ])
+
# Functions that have to be cast to FP32 due to possible
# overflows
FP32_FUNCS = [
- 'Convolution_v1',
'IdentityAttachKLSparseReg',
'arccos',
- '_sparse_arccos',
'arcsin',
'cosh',
- '_sparse_cosh',
'erfinv',
'sinh',
'tan',
- '_sparse_tan',
'arctanh',
- '_sparse_arcsin',
- '_sparse_arctanh',
+ '_contrib_calibrate_entropy',
'_contrib_MultiBoxDetection',
'_contrib_MultiBoxPrior',
'_contrib_MultiBoxTarget',
+ '_npi_arccos',
+ '_npi_arcsin',
+ '_npi_arctanh',
+ '_npi_cosh',
+ '_npi_sinh',
+ '_npi_tan',
# Exponents
+ '_npi_exp',
+ '_npi_expm1',
+ '_npi_ldexp',
+ '_npi_ldexp_scalar',
+ '_npi_log',
+ '_npi_log10',
+ '_npi_log1p',
+ '_npi_log2',
+ '_npi_rldexp_scalar',
'exp',
'expm1',
- '_sparse_exp',
- '_sparse_expm1',
'log',
'log10',
'log2',
@@ -380,30 +495,32 @@ FP32_FUNCS = [
# Powers
'broadcast_power',
'square',
- '_sparse_square',
'reciprocal',
- '_RDivScalar',
'_rdiv_scalar',
'rsqrt',
'rcbrt',
- '_Power',
- '_PowerScalar',
'_power',
'_power_scalar',
- '_RPowerScalar',
'_rpower_scalar',
- 'linalg_sumlogdiag',
- '_Hypot',
- '_HypotScalar',
- '_hypot',
- '_hypot_scalar',
- 'broadcast_hypot',
'_square_sum',
'_contrib_hawkesll',
+ '_npi_power',
+ '_npi_power_scalar',
+ '_npi_reciprocal',
+ '_npi_rpower_scalar',
+ '_npi_square',
# Reductions
+ '_npi_average',
+ '_npi_cumsum',
+ '_npi_mean',
+ '_npi_polyval',
+ '_npi_prod',
+ '_npi_std',
+ '_npi_sum',
+ '_npi_trace',
+ '_npi_var',
'sum',
- 'sum_axis',
'nansum',
'prod',
'nanprod',
@@ -414,11 +531,26 @@ FP32_FUNCS = [
'moments',
# Misc
+ '_npi_cholesky',
+ '_npi_eig',
+ '_npi_eigh',
+ '_npi_eigvals',
+ '_npi_eigvalsh',
+ '_npi_lstsq',
+ '_npi_matrix_rank',
+ '_npi_matrix_rank_none_tol',
+ '_npi_norm',
+ '_npi_pinv',
+ '_npi_pinv_scalar_rcond',
+ '_npi_qr',
+ '_npi_solve',
+ '_npi_svd',
+ '_npi_tensorinv',
+ '_npi_tensorsolve',
+ 'digamma',
'gamma',
'gammaln',
'_linalg_gelqf',
- '_linalg_gemm',
- '_linalg_gemm2',
'_linalg_potrf',
'_linalg_potri',
'_linalg_sumlogdiag',
@@ -433,42 +565,16 @@ FP32_FUNCS = [
'_linalg_inverse',
'_linalg_det',
'_linalg_slogdet',
- 'linalg_syrk',
- 'linalg_potrf',
- 'linalg_potri',
- 'linalg_gemm2',
- 'linalg_gemm',
- 'linalg_gelqf',
- 'linalg_trmm',
- 'linalg_trsm',
- 'linalg_makediag',
- 'linalg_extractdiag',
- 'linalg_maketrian',
- 'linalg_extracttrian',
- 'linalg_inverse',
- 'linalg_det',
- 'linalg_slogdet',
'_NDArray',
'_Native',
'_contrib_count_sketch',
'_contrib_SyncBatchNorm',
'_contrib_fft',
- '_sparse_gamma',
- '_sparse_gammaln',
- '_sparse_log',
- '_sparse_log10',
- '_sparse_log1p',
- '_sparse_log2',
- '_sparse_make_loss',
- '_sparse_mean',
- '_sparse_norm',
- '_sparse_rsqrt',
'argsort',
'topk',
# Neural network
'softmax',
- 'Softmax',
'log_softmax',
'InstanceNorm',
'LayerNorm',
@@ -482,13 +588,17 @@ FP32_FUNCS = [
'make_loss',
'Custom',
'CTCLoss',
- '_contrib_CTCLoss',
- '_contrib_ctc_loss',
- 'ctc_loss',
'_npx_deformable_convolution',
+ '_npx_modulated_deformable_convolution',
'_contrib_DeformablePSROIPooling',
]
+if Features().is_enabled('MKLDNN'):
+ FP32_FUNCS.extend([
+ '_sg_mkldnn_conv',
+ '_sg_mkldnn_fully_connected',
+ ])
+
# Functions that have to be cast to FP32 only for
# some values of their parameters
CONDITIONAL_FP32_FUNCS = [
@@ -499,51 +609,59 @@ CONDITIONAL_FP32_FUNCS = [
# Functions with multiple inputs, that need the same
# type of all their inputs
WIDEST_TYPE_CASTS = [
- '_Plus',
- '_plus',
- '_Minus',
- '_sub',
- '_Mul',
- '_Div',
- '_div',
- '_Mod',
- '_Not_Equal',
- '_Equal',
'_equal',
- '_Greater',
'_greater',
- '_Greater_Equal',
'_greater_equal',
- '_Lesser',
- '_Lesser_Equal',
+ '_hypot',
'_lesser',
'_lesser_equal',
- '_Logical_And',
- '_Logical_Or',
- '_Logical_Xor',
'_logical_and',
'_logical_or',
'_logical_xor',
'_maximum',
'_minimum',
- '_minus',
'_mod',
- '_mul',
'_not_equal',
+ '_npi_column_stack',
+ '_npi_concatenate',
+ '_npi_copysign',
+ '_npi_cross',
+ '_npi_dot',
+ '_npi_ediff1d',
+ '_npi_equal',
+ '_npi_fmax',
+ '_npi_fmin',
+ '_npi_fmod',
+ '_npi_greater',
+ '_npi_greater_equal',
+ '_npi_hypot',
+ '_npi_kron',
+ '_npi_less',
+ '_npi_less_equal',
+ '_npi_logical_and',
+ '_npi_logical_or',
+ '_npi_logical_xor',
+ '_npi_not_equal',
+ '_npi_dstack',
+ '_npi_hstack',
+ '_npi_stack',
+ '_npi_tensordot',
+ '_npi_tensordot_int_axes',
+ '_npi_vstack',
+ '_npi_where',
+ '_npx_index_add',
+ '_npx_index_update',
'Concat',
- 'concat',
+ '_contrib_RROIAlign',
'Correlation',
- 'ElementWiseSum',
- '_sparse_ElementWiseSum',
'add_n',
- '_sparse_add_n',
'batch_dot',
'broadcast_add',
- 'broadcast_plus',
'broadcast_div',
'broadcast_equal',
'broadcast_greater',
'broadcast_greater_equal',
+ 'broadcast_hypot',
'broadcast_lesser',
'broadcast_lesser_equal',
'broadcast_logical_and',
@@ -551,7 +669,6 @@ WIDEST_TYPE_CASTS = [
'broadcast_logical_xor',
'broadcast_maximum',
'broadcast_minimum',
- 'broadcast_minus',
'broadcast_mod',
'broadcast_mul',
'broadcast_not_equal',
@@ -562,15 +679,14 @@ WIDEST_TYPE_CASTS = [
'elemwise_mul',
'elemwise_sub',
'stack',
- '_Maximum',
- '_Minimum',
'_contrib_MultiProposal',
'_contrib_PSROIPooling',
'_contrib_Proposal',
'_contrib_ROIAlign',
+ '_contrib_box_decode',
+ '_contrib_box_encode',
'_contrib_box_iou',
'_contrib_box_nms',
- '_contrib_box_non_maximum_suppression',
'_contrib_dgl_adjacency',
'_contrib_dgl_csr_neighbor_non_uniform_sample',
'_contrib_dgl_csr_neighbor_uniform_sample',
@@ -582,28 +698,7 @@ WIDEST_TYPE_CASTS = [
'_contrib_interleaved_matmul_selfatt_qk',
'_contrib_interleaved_matmul_selfatt_valatt',
'where',
- '_sparse_where',
- '_sparse_broadcast_add',
- '_sparse_broadcast_div',
- '_sparse_broadcast_minus',
- '_sparse_broadcast_mul',
- '_sparse_broadcast_plus',
- '_sparse_broadcast_sub',
- '_sparse_dot',
- '_sparse_elemwise_add',
- '_sparse_elemwise_div',
- '_sparse_elemwise_mul',
- '_sparse_elemwise_sub',
- '_sparse_sum',
- 'random_pdf_gamma',
- 'random_pdf_exponential',
- 'random_pdf_uniform',
- 'random_pdf_negative_binomial',
- 'random_pdf_generalized_negative_binomial',
- 'random_pdf_dirichlet',
- 'random_pdf_normal',
- 'random_pdf_poisson',
'_random_pdf_gamma',
'_random_pdf_exponential',
'_random_pdf_uniform',
diff --git a/python/mxnet/contrib/amp/loss_scaler.py
b/python/mxnet/contrib/amp/loss_scaler.py
index 3a177ce..771408e 100644
--- a/python/mxnet/contrib/amp/loss_scaler.py
+++ b/python/mxnet/contrib/amp/loss_scaler.py
@@ -19,9 +19,9 @@
"""Dynamic loss scaler for AMP."""
import logging
-from ...ndarray import multi_all_finite
-from ...ndarray import ndarray as nd
from ... import autograd as ag
+from ... import ndarray
+from ...util import is_np_array
class LossScaler(object):
"""Dynamic loss scaler for AMP.
@@ -44,15 +44,21 @@ class LossScaler(object):
def has_overflow(self, params):
"""Check gradients for overflow."""
+ if is_np_array():
+ all_finite_f = ndarray.numpy._internal.multi_all_finite
+ ones_f = ndarray.numpy.ones
+ else:
+ all_finite_f = ndarray.multi_all_finite
+ ones_f = ndarray.ones
with ag.pause():
chunk_size = 200
valid_params = [p._grad[0] for p in params if p._grad is not None]
- gpu_output = nd.ones((1,), ctx=valid_params[0].context)
+ gpu_output = ones_f((1,), ctx=valid_params[0].context)
nb_params = len(valid_params)
for idx in range(0, nb_params, chunk_size):
- multi_all_finite(*valid_params[idx:idx+chunk_size],
-
num_arrays=len(valid_params[idx:idx+chunk_size]),
- init_output=False, out=gpu_output)
+ all_finite_f(*valid_params[idx:idx+chunk_size],
+ num_arrays=len(valid_params[idx:idx+chunk_size]),
+ init_output=False, out=gpu_output)
has_overflow = not bool(gpu_output.asnumpy())
self._loss_scale = self._next_loss_scale
if has_overflow:
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 04efbee..b34fe0d 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -1143,6 +1143,32 @@ def get_all_registered_operators():
return mx_registered_operator_names
+def get_all_registered_operators_grouped():
+ """Get all registered MXNet operator names, grouped by 'original' operator.
+
+ Returns
+ -------
+ names : a dictionary, mapping op name to the list of all its aliases
(including the original).
+ """
+ ret = {}
+ for aname in get_all_registered_operators():
+ op_handle = OpHandle()
+ check_call(_LIB.NNGetOpHandle(c_str(aname), ctypes.byref(op_handle)))
+ name = ctypes.c_char_p()
+ desc = ctypes.c_char_p()
+ num_args = mx_uint()
+ arg_names = ctypes.POINTER(ctypes.c_char_p)()
+ arg_types = ctypes.POINTER(ctypes.c_char_p)()
+ arg_descs = ctypes.POINTER(ctypes.c_char_p)()
+ ret_types = ctypes.POINTER(ctypes.c_char_p)()
+ check_call(_LIB.NNGetOpInfo(op_handle, ctypes.byref(name),
ctypes.byref(desc),
+ ctypes.byref(num_args),
ctypes.byref(arg_names),
+ ctypes.byref(arg_types),
ctypes.byref(arg_descs),
+ ctypes.byref(ret_types)))
+ ret.setdefault(py_str(name.value), []).append(aname)
+ return ret
+
+
OperatorArguments = collections.namedtuple('OperatorArguments', ['narg',
'names', 'types'])
diff --git a/src/operator/contrib/all_finite.cc
b/src/operator/contrib/all_finite.cc
index 5e77510..8ef3671 100755
--- a/src/operator/contrib/all_finite.cc
+++ b/src/operator/contrib/all_finite.cc
@@ -97,6 +97,7 @@ inline void MultiAllFiniteCPU(const nnvm::NodeAttrs& attrs,
DMLC_REGISTER_PARAMETER(AllFiniteParam);
NNVM_REGISTER_OP(all_finite)
+.add_alias("_npi_all_finite")
.describe(R"code(Check if all the float numbers in the array are finite (used
for AMP)
)code" ADD_FILELINE)
.set_num_inputs(1)
@@ -129,6 +130,7 @@ NNVM_REGISTER_OP(all_finite)
DMLC_REGISTER_PARAMETER(MultiAllFiniteParam);
NNVM_REGISTER_OP(multi_all_finite)
+.add_alias("_npi_multi_all_finite")
.describe(R"code(Check if all the float numbers in all the arrays are finite
(used for AMP)
)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index 7690783..0acb699 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -115,6 +115,7 @@ inline static bool AMPMultiCastStorageType(const
nnvm::NodeAttrs& attrs, const i
#endif // MXNET_USE_MKLDNN == 1
NNVM_REGISTER_OP(amp_cast)
+.add_alias("_npi_amp_cast")
.describe(R"code(Cast function between low precision float/FP32 used by AMP.
It casts only between low precision float/FP32 and does not do anything for
other types.
@@ -158,6 +159,7 @@ NNVM_REGISTER_OP(_backward_amp_cast)
.set_attr<FCompute>("FCompute<cpu>", AMPCastCompute<cpu>);
NNVM_REGISTER_OP(amp_multicast)
+.add_alias("_npi_amp_multicast")
.describe(R"code(Cast function used by AMP, that casts its inputs to the
common widest type.
It casts only between low precision float/FP32 and does not do anything for
other types.
diff --git a/tests/python/gpu/test_contrib_amp.py
b/tests/python/gpu/test_contrib_amp.py
index d7a6e80..6895723 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -29,6 +29,7 @@ from mxnet.test_utils import set_default_context,
same_symbol_structure
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon import SymbolBlock, nn, rnn
from mxnet.contrib.amp import amp
+from mxnet.operator import get_all_registered_operators_grouped
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed, teardown_module,
assert_raises_cudnn_not_satisfied
@@ -42,7 +43,6 @@ def amp_tests(request):
request.addfinalizer(teardown)
[email protected](reason='Error during waitall(). Tracked in #18099')
def test_amp_coverage(amp_tests):
conditional = [item[0] for item in
amp.lists.symbol_fp16.CONDITIONAL_FP32_FUNCS]
@@ -66,42 +66,36 @@ def test_amp_coverage(amp_tests):
assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP
list."
# Check the coverage
- py_str = lambda x: x.decode('utf-8')
-
- plist = ctypes.POINTER(ctypes.c_char_p)()
- size = ctypes.c_uint()
-
- mx.base._LIB.MXListAllOpNames(ctypes.byref(size),
- ctypes.byref(plist))
- op_names = []
- for i in range(size.value):
- s = py_str(plist[i])
- if not s.startswith("_backward") \
- and not s.startswith("_contrib_backward_"):
- op_names.append(s)
-
- ret1 = set(op_names) - set(t)
-
- if ret1 != set():
- warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists
(in "
- "python/mxnet/contrib/amp/lists/symbol_fp16.py) -
please add them. "
- """Please follow these guidelines for choosing a proper
list:
- - if your operator is not to be used in a computational
graph
- (e.g. image manipulation operators, optimizers) or
does not have
- inputs, put it in FP16_FP32_FUNCS list,
- - if your operator requires FP32 inputs or is not safe
to use with lower
- precision, put it in FP32_FUNCS list,
- - if your operator supports both FP32 and lower
precision, has
- multiple inputs and expects all inputs to be of the
same
- type, put it in WIDEST_TYPE_CASTS list,
- - if your operator supports both FP32 and lower
precision and has
- either a single input or supports inputs of different
type,
- put it in FP16_FP32_FUNCS list,
- - if your operator is both safe to use in lower
precision and
- it is highly beneficial to use it in lower precision,
then
- put it in FP16_FUNCS (this is unlikely for new
operators)
- - If you are not sure which list to choose, FP32_FUNCS
is the
- safest option""")
+ covered = set(t)
+ ops = get_all_registered_operators_grouped()
+ required = set(k for k in ops
+ if not k.startswith(("_backward", "_contrib_backward",
"_npi_backward")) and
+ not k.endswith("_backward"))
+
+ extra = covered - required
+ assert not extra, f"{len(extra)} operators are not needed in the AMP
lists: {sorted(extra)}"
+
+ guidelines = """Please follow these guidelines for choosing a proper list:
+ - if your operator is not to be used in a computational graph
+ (e.g. image manipulation operators, optimizers) or does not have
+ inputs, put it in FP16_FP32_FUNCS list,
+ - if your operator requires FP32 inputs or is not safe to use with lower
+ precision, put it in FP32_FUNCS list,
+ - if your operator supports both FP32 and lower precision, has
+ multiple inputs and expects all inputs to be of the same
+ type, put it in WIDEST_TYPE_CASTS list,
+ - if your operator supports both FP32 and lower precision and has
+ either a single input or supports inputs of different type,
+ put it in FP16_FP32_FUNCS list,
+ - if your operator is both safe to use in lower precision and
+ it is highly beneficial to use it in lower precision, then
+ put it in FP16_FUNCS (this is unlikely for new operators)
+ - If you are not sure which list to choose, FP32_FUNCS is the
+ safest option"""
+ diff = required - covered
+ assert not diff, f"{len(diff)} operators {sorted(diff)} do not exist in
AMP lists (in " \
+ f"python/mxnet/contrib/amp/lists/symbol_fp16.py) - please add them. " \
+ f"\n{guidelines}"
@with_seed()
@pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
@@ -120,7 +114,6 @@ def test_amp_conversion_rnn(amp_tests):
@with_seed()
[email protected](reason='Error during waitall(). Tracked in #18099')
def test_fp16_casting(amp_tests):
data = mx.sym.var("data")
out1 = mx.sym.amp_cast(data, dtype="float16")