This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f6d1ed1872 Improve bf16 support (#21002)
f6d1ed1872 is described below

commit f6d1ed1872b745c3a9dd944b7ad5c2ce988b014b
Author: Paweł Głomski <[email protected]>
AuthorDate: Fri Jul 15 10:56:38 2022 +0200

    Improve bf16 support (#21002)
    
    * AMP improvements + enable bf16 input for quantize_v2
    
    * Fix sanity
    
    * Improve tests, AMP conversion interface, fix forward hooks
    
    * Fix tests
    
    * Fix imports in tests
    
    * Use different lp16_fp32 op in test
    
    * Add amp.disable_amp() context, fix tests
    
    * Add tests, generalize optimization disabling
    
    * Fix sanity
    
    * Review fixes
    
    * Use is_integral<>::value
    
    * Review fixes
    Change flag type to unsigned int
    Add a warning for an incorrect flag attribute value
    
    * Extend bf16 support
    
    * Combine enable_float_output and amp_out_dtype parameters
    
    * Add bf16 support to _dnnl_batch_dot
    
    * Fix sanity
    
    * Add bf16 support to all dnnl ops, add tests
    
    * Add license
    
    * Fix conv activation fuse, disable masked_softmax bf16 support
    
    * Fix sanity, add softmax test cases
    
    * Compare bf16 outputs with fp32 reference
    
    Co-authored-by: Bartlomiej Gawrych <[email protected]>
---
 3rdparty/mshadow/mshadow/base.h                    |  20 +-
 3rdparty/mshadow/mshadow/bfloat.h                  |   2 +
 include/mxnet/imperative.h                         |   1 +
 python/mxnet/amp/lists/symbol_bf16.py              | 688 +++++++++++++++------
 src/common/utils.h                                 |   7 +-
 src/imperative/imperative.cc                       |   2 +-
 src/nnvm/low_precision_pass.cc                     |   6 +-
 src/operator/nn/dnnl/dnnl_base-inl.h               |   9 +
 src/operator/nn/dnnl/dnnl_batch_dot-inl.h          |   6 +-
 src/operator/nn/dnnl/dnnl_batch_dot.cc             |   4 +-
 src/operator/nn/dnnl/dnnl_convolution-inl.h        |  10 +-
 src/operator/nn/dnnl/dnnl_fully_connected-inl.h    |  12 +-
 src/operator/nn/dnnl/dnnl_reduce.cc                |   2 +-
 src/operator/numpy/np_elemwise_broadcast_op.h      |   6 +-
 src/operator/numpy/np_true_divide-inl.h            |   6 +-
 src/operator/random/sample_op.h                    |  34 +-
 src/operator/subgraph/dnnl/dnnl_batch_dot.cc       |  34 +-
 src/operator/subgraph/dnnl/dnnl_conv.cc            |  32 +-
 src/operator/subgraph/dnnl/dnnl_fc.cc              |  37 +-
 .../subgraph/dnnl/dnnl_fc_sum_fuse_property.h      |   8 +-
 .../subgraph/dnnl/dnnl_post_amp_property.h         |   2 +-
 .../subgraph/dnnl/dnnl_post_quantize_property.h    |   2 +-
 src/operator/subgraph/dnnl/dnnl_transformer-inl.h  |  14 +-
 src/operator/subgraph/dnnl/dnnl_transformer.cc     |  39 +-
 src/operator/tensor/elemwise_unary_op.h            |   8 +-
 tests/python/dnnl/op_cfg.py                        | 410 ++++++++++++
 tests/python/dnnl/subgraphs/subgraph_common.py     |   4 +-
 tests/python/dnnl/subgraphs/test_amp_subgraph.py   |  31 +-
 tests/python/dnnl/subgraphs/test_fc_subgraph.py    |   2 +-
 tests/python/dnnl/test_amp.py                      |  66 ++
 30 files changed, 1152 insertions(+), 352 deletions(-)

diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h
index a160b52201..6181b3e19a 100644
--- a/3rdparty/mshadow/mshadow/base.h
+++ b/3rdparty/mshadow/mshadow/base.h
@@ -409,7 +409,7 @@ struct DataType<half::half_t> {
 #endif
 #endif
 };
-template<>
+template <>
 struct DataType<bfloat::bf16_t> {
   static const int kFlag = kBfloat16;
   static const int kLanes = 1;
@@ -769,6 +769,10 @@ namespace isnan_typed {
   MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
     return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
   }
+  template <>
+  MSHADOW_XINLINE bool IsNan(volatile mshadow::bfloat::bf16_t val) {
+    return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) > MSHADOW_BF16_EXPONENT_BITS;
+  }
 }  // namespace isnan_typed
 
 /*! \brief
@@ -795,6 +799,10 @@ namespace isinf_typed {
   MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
     return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
   }
+  template <>
+  MSHADOW_XINLINE bool IsInf(volatile mshadow::bfloat::bf16_t val) {
+    return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) == MSHADOW_BF16_EXPONENT_BITS;
+  }
 }  // namespace isinf_typed
 
 /*! \brief namespace for potential reducer operations */
@@ -881,6 +889,11 @@ MSHADOW_XINLINE half::half_t 
NegInfValue<half::half_t>(void) {
   return half::half_t::Binary(
       MSHADOW_HALF_SIGN_BIT | MSHADOW_HALF_EXPONENT_BITS);
 }
+/*! \brief negative infinity value of bfloat16 */
+template <>
+MSHADOW_XINLINE bfloat::bf16_t NegInfValue<bfloat::bf16_t>(void) {
+  return bfloat::bf16_t::Binary(MSHADOW_BF16_SIGN_BIT | MSHADOW_BF16_EXPONENT_BITS);
+}
 
 /*!
  * \brief maximum value of certain types
@@ -962,6 +975,11 @@ template<>
 MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
   return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
 }
+/*! \brief positive infinity value of bfloat16 */
+template <>
+MSHADOW_XINLINE bfloat::bf16_t PosInfValue<bfloat::bf16_t>(void) {
+  return bfloat::bf16_t::Binary(MSHADOW_BF16_EXPONENT_BITS);
+}
 
 }  // namespace limits
 
diff --git a/3rdparty/mshadow/mshadow/bfloat.h 
b/3rdparty/mshadow/mshadow/bfloat.h
index 94bbb4acf5..8a9f5b0aaf 100644
--- a/3rdparty/mshadow/mshadow/bfloat.h
+++ b/3rdparty/mshadow/mshadow/bfloat.h
@@ -180,6 +180,8 @@ MSHADOW_BF16_OPERATOR(bool, <=)
 
 #define MSHADOW_BF16_MIN mshadow::bfloat::bf16_t::Binary(0xFF7F);
 #define MSHADOW_BF16_MAX mshadow::bfloat::bf16_t::Binary(0x7F7F);
+#define MSHADOW_BF16_SIGN_BIT      0x8000
+#define MSHADOW_BF16_EXPONENT_BITS 0x7f80
 }  // namespace bfloat
 }  // namespace mshadow
 #endif  // MSHADOW_BFLOAT_H_
\ No newline at end of file
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 42876f7bf4..75cf3e2a80 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -42,6 +42,7 @@ enum class OptConstraint : unsigned int {
   DisableAMP = 1 << 0
   // DisableQuantization = 1 << 1
 };
+using OptConstraint_int_t = std::underlying_type_t<OptConstraint>;
 
 /*! \brief there are three numpy shape flags based on priority.
  * GlobalOn
diff --git a/python/mxnet/amp/lists/symbol_bf16.py 
b/python/mxnet/amp/lists/symbol_bf16.py
index 566990c411..89ddea2820 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -40,19 +40,69 @@ if Features.instance.is_enabled('ONEDNN'):
 # like image transformations or optimizers) or they
 # are dtype neutral (can work in both bf16 and fp32)
 BF16_FP32_FUNCS = [
-    'abs',
+    '_contrib_AdaptiveAvgPooling2D',
+    '_contrib_BatchNormWithReLU',
+    'Activation',
     'BatchNorm',
-    'clip',
-    'Flatten',
+    'LayerNorm',
     'LRN',
+    'softmax',
+    'log_softmax',
+    #'masked_softmax', TODO: fix segfault appearing for a 4D input tensor
     'Pooling',
-    'relu',
-    '_shuffle',
-    'sqrt',
-    'square',
-    'tanh',
+    '_npi_mean',
+    '_npi_sum',
+    '_npi_square',
+    '_npi_sqrt',
+    '_npi_exp',
+    '_npi_tanh',
+    '_npi_transpose',
+    '_npx_reshape',
+    '_npi_where',
+    #'_contrib_quantize_asym', # used in rnn, which is hard to convert to bf16
     '_contrib_quantize_v2',
+    #'_contrib_quantize', # not used anymore
+    'sum',
+    'mean',
+    '_copy',
+    'Reshape',
+    'Flatten',
+    'transpose',
+    'expand_dims',
+    'slice',
+    'stack',
+    'space_to_depth',
+    '_split_v2',
+
+    # no oneDNN support:
+    'Cast',
+    'where',
+    'take',
 ]
+# 'RNN', # GetEnv("MXNET_USE_ONEDNN_RNN", 1)
+
+# Functions with multiple inputs, that need the same
+# type of all their inputs
+WIDEST_TYPE_CASTS = [
+    'Concat',
+    'dot',
+    'batch_dot',
+    'broadcast_add',
+    'broadcast_sub',
+    'broadcast_mul',
+    'broadcast_div',
+    'elemwise_add',
+    'add_n',
+    '_npi_dot',
+    '_npi_add',
+    '_npi_multiply',
+    '_npi_subtract',
+    '_npi_true_divide',
+]
+if Features.instance.is_enabled('ONEDNN'):
+    WIDEST_TYPE_CASTS.extend([
+        '_sg_onednn_batch_dot',
+    ])
 
 # Functions that when running with Bfloat16, the params that still need 
float32.
 BF16_USE_FP32_PARAMS = {
@@ -63,106 +113,439 @@ BF16_USE_FP32_PARAMS = {
 # Functions that have to be cast to FP32 due to possible
 # overflows
 FP32_FUNCS = [
-    'RNN',
+    'amp_cast',
+    'amp_multicast',
+    'masked_softmax',
     'BilinearSampler',
     'BlockGrad',
-    'Cast',
-    'cast_storage',
+    'CTCLoss',
+    'Correlation',
     'Crop',
+    'Custom',
     'Dropout',
     'Embedding',
     'GridGenerator',
+    'GroupNorm',
+    'IdentityAttachKLSparseReg',
+    'InstanceNorm',
+    'L2Normalization',
+    'LinearRegressionOutput',
+    'LogisticRegressionOutput',
+    'MAERegressionOutput',
+    'MakeLoss',
     'Pad',
+    'RNN',
     'ROIPooling',
-    'Reshape',
+    'SVMOutput',
     'SequenceLast',
     'SequenceMask',
     'SequenceReverse',
     'SliceChannel',
+    'SoftmaxActivation',
+    'SoftmaxOutput',
     'SpatialTransformer',
     'SwapAxis',
     'UpSampling',
     '_CachedOp',
+    '_CachedOpThreadSafe',
     '_CrossDeviceCopy',
     '_CustomFunction',
+    '_NDArray',
+    '_Native',
     '_NoGradient',
+    '_adabelief_update',
     '_adamw_update',
     '_arange',
     '_cond',
-    '_contrib_interleaved_matmul_selfatt_qk',
-    '_contrib_interleaved_matmul_selfatt_valatt',
-    '_contrib_AdaptiveAvgPooling2D',
     '_contrib_BilinearResize2D',
+    '_contrib_DeformablePSROIPooling',
+    '_contrib_MultiBoxDetection',
+    '_contrib_MultiBoxPrior',
+    '_contrib_MultiBoxTarget',
+    '_contrib_MultiProposal',
+    '_contrib_PSROIPooling',
+    '_contrib_Proposal',
+    '_contrib_ROIAlign',
+    '_contrib_RROIAlign',
+    '_contrib_SyncBatchNorm',
+    '_contrib_allclose',
+    '_contrib_arange_like',
     '_contrib_bipartite_matching',
+    '_contrib_boolean_mask',
+    '_contrib_box_decode',
+    '_contrib_box_encode',
+    '_contrib_box_iou',
+    '_contrib_box_nms',
+    '_contrib_calibrate_entropy',
+    '_contrib_count_sketch',
     '_contrib_dequantize',
+    '_contrib_dgl_adjacency',
+    '_contrib_dgl_csr_neighbor_non_uniform_sample',
+    '_contrib_dgl_csr_neighbor_uniform_sample',
+    '_contrib_dgl_graph_compact',
+    '_contrib_dgl_subgraph',
     '_contrib_div_sqrt_dim',
-    '_contrib_boolean_mask',
+    '_contrib_dynamic_reshape',
+    '_contrib_edge_id',
+    '_contrib_fft',
     '_contrib_getnnz',
     '_contrib_gradientmultiplier',
     '_contrib_group_adagrad_update',
+    '_contrib_hawkesll',
     '_contrib_index_array',
     '_contrib_index_copy',
+    '_contrib_interleaved_matmul_encdec_qk',
+    '_contrib_interleaved_matmul_encdec_valatt',
+    '_contrib_interleaved_matmul_selfatt_qk',
+    '_contrib_interleaved_matmul_selfatt_valatt',
+    '_contrib_intgemm_fully_connected',
+    '_contrib_intgemm_maxabsolute',
+    '_contrib_intgemm_prepare_data',
+    '_contrib_intgemm_prepare_weight',
+    '_contrib_intgemm_take_weight',
     '_contrib_quadratic',
     '_contrib_quantize',
     '_contrib_quantize_asym',
+    '_contrib_quantized_act',
+    '_contrib_quantized_batch_norm',
     '_contrib_quantized_concat',
     '_contrib_quantized_conv',
+    '_contrib_quantized_elemwise_add',
+    '_contrib_quantized_elemwise_mul',
+    '_contrib_quantized_embedding',
     '_contrib_quantized_flatten',
     '_contrib_quantized_fully_connected',
     '_contrib_quantized_pooling',
-    '_contrib_quantized_elemwise_add',
-    '_contrib_quantized_act',
+    '_contrib_quantized_reshape',
     '_contrib_quantized_rnn',
-    '_image_crop',
-    '_linspace',
+    '_contrib_quantized_transpose',
     '_contrib_requantize',
-    '_copy',
+    '_contrib_round_ste',
+    '_contrib_sign_ste',
+    '_contrib_sldwin_atten_context',
+    '_contrib_sldwin_atten_mask_like',
+    '_contrib_sldwin_atten_score',
     '_copyto',
     '_cvcopyMakeBorder',
     '_cvimdecode',
     '_cvimread',
     '_cvimresize',
     '_div_scalar',
+    '_equal',
     '_equal_scalar',
     '_eye',
     '_foreach',
-    '_while_loop',
     '_full',
     '_grad_add',
-    '_greater_scalar',
+    '_greater',
+    '_greater_equal',
     '_greater_equal_scalar',
+    '_greater_scalar',
     '_histogram',
+    '_hypot',
+    '_hypot_scalar',
     '_identity_with_attr_like_rhs',
     '_image_adjust_lighting',
+    '_image_crop',
     '_image_flip_left_right',
     '_image_flip_top_bottom',
     '_image_normalize',
     '_image_random_brightness',
     '_image_random_color_jitter',
     '_image_random_contrast',
+    '_image_random_crop',
     '_image_random_flip_left_right',
     '_image_random_flip_top_bottom',
     '_image_random_hue',
     '_image_random_lighting',
+    '_image_random_resized_crop',
     '_image_random_saturation',
     '_image_resize',
     '_image_to_tensor',
     '_imdecode',
-    '_lesser_scalar',
+    '_lesser',
+    '_lesser_equal',
     '_lesser_equal_scalar',
+    '_lesser_scalar',
+    '_linalg_det',
+    '_linalg_extractdiag',
+    '_linalg_extracttrian',
+    '_linalg_gelqf',
+    '_linalg_gemm',
+    '_linalg_gemm2',
+    '_linalg_inverse',
+    '_linalg_makediag',
+    '_linalg_maketrian',
+    '_linalg_potrf',
+    '_linalg_potri',
+    '_linalg_slogdet',
+    '_linalg_sumlogdiag',
+    '_linalg_syevd',
+    '_linalg_syrk',
+    '_linalg_trmm',
+    '_linalg_trsm',
+    '_linspace',
+    '_logical_and',
     '_logical_and_scalar',
+    '_logical_or',
     '_logical_or_scalar',
+    '_logical_xor',
     '_logical_xor_scalar',
+    '_maximum',
     '_maximum_scalar',
+    '_minimum',
     '_minimum_scalar',
     '_minus_scalar',
+    '_mod',
     '_mod_scalar',
+    '_mp_adabelief_update',
     '_mp_adamw_update',
     '_mul_scalar',
+    '_multi_adabelief_update',
+    '_multi_adamw_update',
+    '_multi_lamb_update',
+    '_multi_lans_update',
+    '_multi_mp_adabelief_update',
+    '_multi_mp_adamw_update',
+    '_multi_mp_lamb_update',
+    '_multi_mp_lans_update',
+    '_not_equal',
     '_not_equal_scalar',
+    '_np_reshape',
+    '_npi_absolute',
+    '_npi_add_scalar',
+    '_npi_advanced_indexing',
+    '_npi_advanced_indexing_multiple',
+    '_npi_all',
+    '_npi_any',
+    '_npi_arange',
+    '_npi_arccos',
+    '_npi_arccosh',
+    '_npi_arcsin',
+    '_npi_arcsinh',
+    '_npi_arctan',
+    '_npi_arctan2',
+    '_npi_arctan2_scalar',
+    '_npi_arctanh',
+    '_npi_argmax',
+    '_npi_argmin',
+    '_npi_around',
+    '_npi_atleast_1d',
+    '_npi_atleast_2d',
+    '_npi_atleast_3d',
+    '_npi_average',
+    '_npi_bernoulli',
+    '_npi_bincount',
+    '_npi_bitwise_and',
+    '_npi_bitwise_and_scalar',
+    '_npi_bitwise_left_shift',
+    '_npi_bitwise_left_shift_scalar',
+    '_npi_bitwise_not',
+    '_npi_bitwise_or',
+    '_npi_bitwise_or_scalar',
+    '_npi_bitwise_right_shift',
+    '_npi_bitwise_right_shift_scalar',
+    '_npi_bitwise_xor',
+    '_npi_bitwise_xor_scalar',
+    '_npi_blackman',
+    '_npi_boolean_mask_assign_scalar',
+    '_npi_boolean_mask_assign_tensor',
+    '_npi_broadcast_to',
+    '_npi_cbrt',
+    '_npi_ceil',
+    '_npi_choice',
+    '_npi_cholesky',
+    '_npi_column_stack',
+    '_npi_copy',
+    '_npi_copysign',
+    '_npi_copysign_scalar',
+    '_npi_cos',
+    '_npi_cosh',
+    '_npi_cross',
+    '_npi_cumsum',
+    '_npi_degrees',
+    '_npi_delete',
+    '_npi_diag',
+    '_npi_diag_indices_from',
+    '_npi_diagflat',
+    '_npi_diagonal',
+    '_npi_diff',
+    '_npi_dsplit',
+    '_npi_dstack',
+    '_npi_ediff1d',
+    '_npi_eig',
+    '_npi_eigh',
+    '_npi_eigvals',
+    '_npi_eigvalsh',
+    '_npi_einsum',
+    '_npi_equal',
+    '_npi_equal_scalar',
+    '_npi_expm1',
+    '_npi_exponential',
+    '_npi_eye',
+    '_npi_fill_diagonal',
+    '_npi_fix',
+    '_npi_flip',
+    '_npi_floor',
+    '_npi_floor_divide',
+    '_npi_floor_divide_scalar',
+    '_npi_fmax',
+    '_npi_fmax_scalar',
+    '_npi_fmin',
+    '_npi_fmin_scalar',
+    '_npi_fmod',
+    '_npi_fmod_scalar',
+    '_npi_full',
+    '_npi_full_like',
+    '_npi_gamma',
+    '_npi_gcd',
+    '_npi_gcd_scalar',
+    '_npi_greater',
+    '_npi_greater_equal',
+    '_npi_greater_equal_scalar',
+    '_npi_greater_scalar',
+    '_npi_gumbel',
+    '_npi_hamming',
+    '_npi_hanning',
+    '_npi_hsplit',
+    '_npi_hstack',
+    '_npi_hypot',
+    '_npi_identity',
+    '_npi_indices',
+    '_npi_insert_scalar',
+    '_npi_insert_slice',
+    '_npi_insert_tensor',
+    '_npi_interp',
+    '_npi_isfinite',
+    '_npi_isinf',
+    '_npi_isnan',
+    '_npi_isneginf',
+    '_npi_isposinf',
+    '_npi_kron',
+    '_npi_laplace',
+    '_npi_lcm',
+    '_npi_lcm_scalar',
+    '_npi_ldexp',
+    '_npi_ldexp_scalar',
+    '_npi_less',
+    '_npi_less_equal',
+    '_npi_less_equal_scalar',
+    '_npi_less_scalar',
+    '_npi_linspace',
+    '_npi_log',
+    '_npi_log10',
+    '_npi_log1p',
+    '_npi_log2',
+    '_npi_logaddexp',
+    '_npi_logaddexp_scalar',
+    '_npi_logical_and',
+    '_npi_logical_and_scalar',
+    '_npi_logical_not',
+    '_npi_logical_or',
+    '_npi_logical_or_scalar',
+    '_npi_logical_xor',
+    '_npi_logical_xor_scalar',
+    '_npi_logistic',
+    '_npi_logspace',
+    '_npi_lstsq',
+    '_npi_matmul',
+    '_npi_matrix_rank',
+    '_npi_matrix_rank_none_tol',
+    '_npi_max',
+    '_npi_min',
+    '_npi_mod',
+    '_npi_mod_scalar',
+    '_npi_moveaxis',
+    '_npi_multinomial',
+    '_npi_multiply_scalar',
+    '_npi_nan_to_num',
+    '_npi_negative',
+    '_npi_norm',
+    '_npi_normal',
+    '_npi_normal_n',
+    '_npi_not_equal',
+    '_npi_not_equal_scalar',
+    '_npi_ones',
+    '_npi_pad',
+    '_npi_pareto',
+    '_npi_percentile',
+    '_npi_pinv',
+    '_npi_pinv_scalar_rcond',
+    '_npi_polyval',
+    '_npi_power',
+    '_npi_power_scalar',
+    '_npi_powerd',
+    '_npi_prod',
+    '_npi_qr',
+    '_npi_radians',
+    '_npi_rarctan2_scalar',
+    '_npi_rayleigh',
+    '_npi_rbitwise_left_shift_scalar',
+    '_npi_rbitwise_right_shift_scalar',
+    '_npi_rcopysign_scalar',
+    '_npi_reciprocal',
+    '_npi_repeats',
+    '_npi_rfloor_divide_scalar',
+    '_npi_rfmod_scalar',
+    '_npi_rint',
+    '_npi_rldexp_scalar',
+    '_npi_rmod_scalar',
+    '_npi_roll',
+    '_npi_rollaxis',
+    '_npi_rot90',
+    '_npi_rpower_scalar',
+    '_npi_rsubtract_scalar',
+    '_npi_rtrue_divide_scalar',
+    '_npi_share_memory',
+    '_npi_sign',
+    '_npi_sin',
+    '_npi_sinh',
+    '_npi_solve',
+    '_npi_squeeze',
+    '_npi_std',
+    '_npi_subtract_scalar',
+    '_npi_svd',
+    '_npi_tan',
+    '_npi_tensordot',
+    '_npi_tensordot_int_axes',
+    '_npi_tensorinv',
+    '_npi_tensorsolve',
+    '_npi_trace',
+    '_npi_tri',
+    '_npi_tril',
+    '_npi_tril_indices',
+    '_npi_triu',
+    '_npi_true_divide_scalar',
+    '_npi_trunc',
+    '_npi_uniform',
+    '_npi_uniform_n',
+    '_npi_unique',
+    '_npi_var',
+    '_npi_vstack',
+    '_npi_weibull',
+    '_npi_where_lscalar',
+    '_npi_where_rscalar',
+    '_npi_where_scalar2',
+    '_npi_zeros',
+    '_npx_cond',
+    '_npx_constraint_check',
+    '_npx_deformable_convolution',
+    '_npx_foreach',
+    '_npx_index_add',
+    '_npx_index_update',
+    '_npx_modulated_deformable_convolution',
+    '_npx_nonzero',
+    '_npx_quantized_reshape',
+    '_npx_quantized_transpose',
+    '_npx_relu',
+    '_npx_sigmoid',
+    '_npx_while_loop',
     '_onehot_encode',
     '_ones',
     '_plus_scalar',
+    '_power',
+    '_power_scalar',
+    '_random_binomial',
     '_random_exponential',
     '_random_exponential_like',
     '_random_gamma',
@@ -173,15 +556,27 @@ FP32_FUNCS = [
     '_random_negative_binomial_like',
     '_random_normal',
     '_random_normal_like',
+    '_random_pdf_dirichlet',
+    '_random_pdf_exponential',
+    '_random_pdf_gamma',
+    '_random_pdf_generalized_negative_binomial',
+    '_random_pdf_negative_binomial',
+    '_random_pdf_normal',
+    '_random_pdf_poisson',
+    '_random_pdf_uniform',
     '_random_poisson',
     '_random_poisson_like',
     '_random_randint',
     '_random_uniform',
     '_random_uniform_like',
     '_ravel_multi_index',
+    '_rdiv_scalar',
     '_rminus_scalar',
     '_rmod_scalar',
     '_rnn_param_concat',
+    '_rpower_scalar',
+    '_sample_binomial',
+    '_sample_categorical',
     '_sample_exponential',
     '_sample_gamma',
     '_sample_generalized_negative_binomial',
@@ -193,67 +588,128 @@ FP32_FUNCS = [
     '_sample_unique_zipfian',
     '_scatter_set_nd',
     '_set_value',
+    '_shuffle',
     '_slice_assign',
     '_slice_assign_scalar',
     '_sparse_adagrad_update',
     '_sparse_retain',
-    '_split_v2',
+    '_square_sum',
     '_unravel_index',
+    '_while_loop',
     '_zeros',
     '_zeros_without_dtype',
+    'abs',
     'adam_update',
     'all_finite',
-    # 'amp_cast',
-    # 'amp_multicast',
+    'arccos',
     'arccosh',
+    'arcsin',
     'arcsinh',
     'arctan',
+    'arctanh',
     'argmax',
     'argmax_channel',
     'argmin',
+    'argsort',
     'batch_take',
     'broadcast_axis',
+    'broadcast_equal',
+    'broadcast_greater',
+    'broadcast_greater_equal',
+    'broadcast_hypot',
+    'broadcast_lesser',
+    'broadcast_lesser_equal',
     'broadcast_like',
+    'broadcast_logical_and',
+    'broadcast_logical_or',
+    'broadcast_logical_xor',
+    'broadcast_maximum',
+    'broadcast_minimum',
+    'broadcast_mod',
+    'broadcast_not_equal',
+    'broadcast_power',
     'broadcast_to',
+    'cast_storage',
     'cbrt',
     'ceil',
+    'clip',
+    'col2im',
     'cos',
+    'cosh',
     'degrees',
     'depth_to_space',
     'diag',
+    'digamma',
+    'elemwise_div',
+    'elemwise_mul',
+    'elemwise_sub',
     'erf',
-    'expand_dims',
+    'erfinv',
+    'exp',
+    'expm1',
     'fill_element_0index',
     'fix',
     'floor',
     'ftml_update',
     'ftrl_update',
+    'gamma',
+    'gammaln',
     'gather_nd',
     'hard_sigmoid',
-    'logical_not',
+    'im2col',
+    'khatri_rao',
+    'lamb_update_phase1',
+    'lamb_update_phase2',
+    'log',
+    'log10',
+    'log1p',
+    'log2',
     'log_sigmoid',
+    'logical_not',
+    'make_loss',
+    'masked_log_softmax',
     'max',
     'min',
     'mish',
+    'moments',
+    'mp_lamb_update_phase1',
+    'mp_lamb_update_phase2',
+    'mp_nag_mom_update',
     'mp_sgd_mom_update',
     'mp_sgd_update',
     'multi_all_finite',
+    'multi_lars',
     'multi_mp_sgd_mom_update',
     'multi_mp_sgd_update',
     'multi_sgd_mom_update',
     'multi_sgd_update',
+    'multi_sum_sq',
+    'nag_mom_update',
+    'nanprod',
+    'nansum',
     'negative',
+    'norm',
     'one_hot',
     'ones_like',
     'pick',
+    'preloaded_multi_mp_sgd_mom_update',
+    'preloaded_multi_mp_sgd_update',
+    'preloaded_multi_sgd_mom_update',
+    'preloaded_multi_sgd_update',
+    'prod',
     'radians',
+    'rcbrt',
+    'reciprocal',
+    'relu',
     'repeat',
+    'reset_arrays',
     'reshape_like',
     'reverse',
     'rint',
     'rmsprop_update',
     'rmspropalex_update',
     'round',
+    'rsqrt',
     'scatter_nd',
     'sgd_mom_update',
     'sgd_update',
@@ -263,184 +719,32 @@ FP32_FUNCS = [
     'signsgd_update',
     'signum_update',
     'sin',
+    'sinh',
     'size_array',
-    'slice',
     'slice_axis',
     'slice_like',
+    'smooth_l1',
+    'softmax_cross_entropy',
+    'softmin',
     'softsign',
     'sort',
-    'space_to_depth',
+    'sqrt',
+    'square',
     'squeeze',
-    'take',
+    'tan',
+    'tanh',
     'tile',
-    'transpose',
+    'topk',
     'trunc',
     'zeros_like',
-    'broadcast_mul',
-    'IdentityAttachKLSparseReg',
-    'arccos',
-    'arcsin',
-    'cosh',
-    'erfinv',
-    'sinh',
-    'tan',
-    'arctanh',
-
-    # Exponents
-    'exp',
-    'expm1',
-    'log',
-    'log10',
-    'log2',
-    'log1p',
-
-    # Powers
-    'broadcast_power',
-    'reciprocal',
-    '_rdiv_scalar',
-    'rsqrt',
-    'rcbrt',
-    '_power',
-    '_power_scalar',
-    '_rpower_scalar',
-    '_hypot',
-    '_hypot_scalar',
-    'broadcast_hypot',
-    '_square_sum',
-    '_contrib_hawkesll',
-
-    # Reductions
-    'sum',
-    'nansum',
-    'prod',
-    'nanprod',
-    'mean',
-    'norm',
-    'softmin',
-    'khatri_rao',
-    'moments',
-
-    # Misc
-    'gamma',
-    'gammaln',
-    '_linalg_gelqf',
-    '_linalg_gemm',
-    '_linalg_gemm2',
-    '_linalg_potrf',
-    '_linalg_potri',
-    '_linalg_sumlogdiag',
-    '_linalg_syevd',
-    '_linalg_syrk',
-    '_linalg_trmm',
-    '_linalg_trsm',
-    '_linalg_makediag',
-    '_linalg_extractdiag',
-    '_linalg_maketrian',
-    '_linalg_extracttrian',
-    '_linalg_inverse',
-    '_linalg_det',
-    '_linalg_slogdet',
-    '_NDArray',
-    '_Native',
-    '_contrib_count_sketch',
-    '_contrib_SyncBatchNorm',
-    '_contrib_fft',
-    'argsort',
-    'topk',
-
-    # Neural network
-    'softmax',
-    'log_softmax',
-    'masked_softmax',
-    'masked_log_softmax',
-    'InstanceNorm',
-    'LayerNorm',
-    'GroupNorm',
-    'L2Normalization',
-    'SoftmaxActivation',
-    'softmax_cross_entropy',
-    'smooth_l1',
-    'MakeLoss',
-    'make_loss',
-    'Custom',
-    'CTCLoss',
-    '_npx_deformable_convolution',
-    '_contrib_DeformablePSROIPooling',
 ]
 
 # Functions that have to be cast to FP32 only for
 # some values of their parameters
 CONDITIONAL_FP32_FUNCS = [
-    ('Activation', 'act_type', ['softrelu']),
-    ('LeakyReLU', 'act_type', ['elu', 'selu']),
-]
-
-# Functions with multiple inputs, that need the same
-# type of all their inputs
-WIDEST_TYPE_CASTS = [
-    '_npi_add',
-    'Concat',
-    '_equal',
-    '_greater',
-    '_greater_equal',
-    '_lesser',
-    '_lesser_equal',
-    '_logical_and',
-    '_logical_or',
-    '_logical_xor',
-    '_maximum',
-    '_minimum',
-    '_mod',
-    '_not_equal',
-    'Correlation',
-    'add_n',
-    'batch_dot',
-    'broadcast_add',
-    'broadcast_div',
-    'broadcast_equal',
-    'broadcast_greater',
-    'broadcast_greater_equal',
-    'broadcast_lesser',
-    'broadcast_lesser_equal',
-    'broadcast_logical_and',
-    'broadcast_logical_or',
-    'broadcast_logical_xor',
-    'broadcast_maximum',
-    'broadcast_minimum',
-    'broadcast_mod',
-    'broadcast_not_equal',
-    'broadcast_sub',
-    'dot',
-    'elemwise_add',
-    'elemwise_div',
-    'elemwise_mul',
-    'elemwise_sub',
-    'stack',
-    '_contrib_MultiBoxDetection',
-    '_contrib_MultiBoxPrior',
-    '_contrib_MultiBoxTarget',
-    '_contrib_MultiProposal',
-    '_contrib_PSROIPooling',
-    '_contrib_Proposal',
-    '_contrib_ROIAlign',
-    '_contrib_box_iou',
-    '_contrib_box_nms',
-    '_contrib_dgl_adjacency',
-    '_contrib_dgl_csr_neighbor_non_uniform_sample',
-    '_contrib_dgl_csr_neighbor_uniform_sample',
-    '_contrib_dgl_graph_compact',
-    '_contrib_dgl_subgraph',
-    '_contrib_edge_id',
-    'where',
-    '_random_pdf_gamma',
-    '_random_pdf_exponential',
-    '_random_pdf_uniform',
-    '_random_pdf_negative_binomial',
-    '_random_pdf_generalized_negative_binomial',
-    '_random_pdf_dirichlet',
-    '_random_pdf_normal',
-    '_random_pdf_poisson',
+    ('LeakyReLU', 'act_type', ['selu']),
 ]
 
 LOSS_OUTPUT_FUNCTIONS = [
+    'SoftmaxOutput'
 ]
diff --git a/src/common/utils.h b/src/common/utils.h
index fe8413f18e..2c1f9d5758 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -929,11 +929,8 @@ inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& 
src) {
 }
 
 inline bool is_float(const int dtype) {
-  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == 
mshadow::kFloat16;
-}
-
-inline bool is_bfloat(const int dtype) {
-  return dtype == mshadow::kBfloat16;
+  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16 ||
+         dtype == mshadow::kBfloat16;
 }
 
 inline bool is_int(const int dtype) {
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fb123c18c9..489737b196 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -371,7 +371,7 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs&& 
attrs,
 
   if (get_opt_constraints() != OptConstraint::None) {
     node->attrs.dict[OPT_CONSTRAINT_ATTR] =
-        
std::to_string(static_cast<std::underlying_type_t<OptConstraint>>(get_opt_constraints()));
+        
std::to_string(static_cast<OptConstraint_int_t>(get_opt_constraints()));
   }
 
   for (uint32_t i = 0; i < outputs.size(); ++i) {
diff --git a/src/nnvm/low_precision_pass.cc b/src/nnvm/low_precision_pass.cc
index 2c2cce5b99..9b26d366f2 100644
--- a/src/nnvm/low_precision_pass.cc
+++ b/src/nnvm/low_precision_pass.cc
@@ -407,10 +407,10 @@ Graph ReducePrecision(Graph&& src) {
       }
       return;
     }
-    auto opt_constraints = 
common::flag_attr_accumulate<std::underlying_type_t<OptConstraint>>(
-        old_node->attrs, OPT_CONSTRAINT_ATTR);
+    auto opt_constraints =
+        common::flag_attr_accumulate<OptConstraint_int_t>(old_node->attrs, 
OPT_CONSTRAINT_ATTR);
     if (fp32_ops.count(old_node->op()->name) > 0 ||
-        (opt_constraints & static_cast<int>(OptConstraint::DisableAMP))) {
+        (opt_constraints & 
static_cast<OptConstraint_int_t>(OptConstraint::DisableAMP))) {
       KeepOriginalNode(old_node, node_map, &entry_map);
     } else if (target_dtype_ops.count(old_node->op()->name) > 0) {
       if (!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, 
&entry_map)) {
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h 
b/src/operator/nn/dnnl/dnnl_base-inl.h
index 0b9645b8a3..7183fbded6 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -58,6 +58,15 @@
       LOG(FATAL) << "Unknown type enum " << type; \
   }
 
+// TODO(PawelGlomski-Intel): add bfloat16 for quantized ops
+#define DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER()                                            \
+  DMLC_DECLARE_FIELD(enabled_float_output)                                                       \
+      .set_default(dmlc::optional<int>())                                                        \
+      .add_enum("float32", mshadow::kFloat32)                                                    \
+      .describe(                                                                                 \
+          "Imposed float output. Used to change the output dtype when the operator operates on " \
+          "low precision data - (u)int8 or lp16.")
+
 namespace mxnet {
 
 // =====  CpuEngine =======================================
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h 
b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index 19233828dc..b7f194cef7 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -44,7 +44,7 @@ struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
 
   dmlc::optional<float> min_calib_range;  // min float value calculated from 
calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from 
calibration dataset
-  bool enable_float_output;
+  dmlc::optional<int> enabled_float_output;
   DMLC_DECLARE_PARAMETER(DNNLDotParam) {
     DMLC_DECLARE_FIELD(transpose_a)
         .describe("If true then transpose the first input before dot.")
@@ -65,9 +65,7 @@ struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
             "The maximum scalar value in the form of float32 obtained "
             "through calibration. If present, it will be used to by "
             "quantized convolution op to calculate primitive scale");
-    DMLC_DECLARE_FIELD(enable_float_output)
-        .set_default(false)
-        .describe("Whether to enable float32 output.");
+    DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
   }
 
   bool operator==(const DNNLDotParam& other) const {
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot.cc 
b/src/operator/nn/dnnl/dnnl_batch_dot.cc
index a40d55621c..f61806ce43 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -79,7 +79,7 @@ dnnl::primitive_attr GetQuantizationAttributes(const 
DNNLDotParam& param,
                                   param.max_calib_range.value()) /
                  lhs_scale_ / rhs_scale_;
     attr.set_output_scales(0, {out_scale_});
-  } else if (param.enable_float_output) {
+  } else if (param.enabled_float_output.has_value()) {
     out_scale_ = 1.0 / lhs_scale_ / rhs_scale_;
     attr.set_output_scales(0, {out_scale_});
   }
@@ -159,7 +159,7 @@ void DNNLBatchDotFwd::Execute(const OpContext& ctx,
   CommitOutput(outputs[0], out_mem);
   DNNLStream::Get()->Submit();
 
-  if (param.quantized && !param.enable_float_output) {
+  if (param.quantized && !param.enabled_float_output.has_value()) {
     mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
     float min_output;
     float max_output;
diff --git a/src/operator/nn/dnnl/dnnl_convolution-inl.h 
b/src/operator/nn/dnnl/dnnl_convolution-inl.h
index 738be8214f..d5a45468e2 100644
--- a/src/operator/nn/dnnl/dnnl_convolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -42,12 +42,11 @@ struct DNNLConvParam : public 
dmlc::Parameter<DNNLConvParam> {
   bool with_sum;
   bool with_postsum_act;
   bool quantized;
-  bool enable_float_output;
   bool dedup_sum;
 
   dmlc::optional<float> min_calib_range;  // min float value calculated from 
calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from 
calibration dataset
-  dmlc::optional<int> amp_out_dtype;      // mshadow dtype of a fused amp_cast 
node
+  dmlc::optional<int> enabled_float_output;
 
   DMLC_DECLARE_PARAMETER(DNNLConvParam) {
     DMLC_DECLARE_FIELD(with_bn).set_default(false).describe("Add post 
batchnorm.");
@@ -57,9 +56,6 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
         .set_default(false)
         .describe("Add post activation after sum");
     DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable 
quantization");
-    DMLC_DECLARE_FIELD(enable_float_output)
-        .set_default(false)
-        .describe("Whether to enable float32 output");
     DMLC_DECLARE_FIELD(dedup_sum).set_default(false).describe("deduplicated 
sum input");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
@@ -73,9 +69,7 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
             "The maximum scalar value in the form of float32 obtained "
             "through calibration. If present, it will be used to by "
             "quantized convolution op to calculate primitive scale");
-    DMLC_DECLARE_FIELD(amp_out_dtype)
-        .set_default(dmlc::optional<int>())
-            MXNET_ADD_ALL_TYPES.describe("The output type deduced from the 
fused amp_cast.");
+    DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
   }
 };
 
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h 
b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index 976dc83fe5..f734b67878 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -41,20 +41,16 @@ namespace op {
 
 struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
   bool quantized;
-  bool enable_float_output;
   bool with_eltwise;
   bool with_sum;
   dmlc::optional<float> min_calib_range;  // min float value calculated from 
calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from 
calibration dataset
   dmlc::optional<bool> channel_wise_quantize;
-  dmlc::optional<int> amp_out_dtype;  // mshadow dtype of a fused amp_cast node
+  dmlc::optional<int> enabled_float_output;
 
   DMLC_DECLARE_PARAMETER(DNNLFCParam) {
     DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
         "Whether it's a quantized FullyConnected operator");
-    DMLC_DECLARE_FIELD(enable_float_output)
-        .set_default(false)
-        .describe("Whether to enable float32 output");
     DMLC_DECLARE_FIELD(with_eltwise)
         .set_default(false)
         .describe("Whether there's a post with_eltwise after FullyConnected 
operator");
@@ -74,9 +70,7 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
     DMLC_DECLARE_FIELD(channel_wise_quantize)
         .set_default(dmlc::optional<bool>())
         .describe("Whether support channel-wise-quantize for weight.");
-    DMLC_DECLARE_FIELD(amp_out_dtype)
-        .set_default(dmlc::optional<int>())
-            MXNET_ADD_ALL_TYPES.describe("The output type deduced from the 
fused amp_cast.");
+    DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
   }
 };
 
@@ -100,7 +94,7 @@ class FCInputIndex {
     const bool has_bias  = !full_param.default_param.no_bias;
     const bool quantized = dnnl_param.quantized;
     const bool sum_input_quantized =
-        quantized && dnnl_param.with_sum && !dnnl_param.enable_float_output;
+        quantized && dnnl_param.with_sum && 
!dnnl_param.enabled_float_output.has_value();
     const bool channel_wise = quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
                               dnnl_param.channel_wise_quantize.value();
 
diff --git a/src/operator/nn/dnnl/dnnl_reduce.cc 
b/src/operator/nn/dnnl/dnnl_reduce.cc
index 36aeec15b3..c5f20abb71 100644
--- a/src/operator/nn/dnnl/dnnl_reduce.cc
+++ b/src/operator/nn/dnnl/dnnl_reduce.cc
@@ -219,7 +219,7 @@ void DNNLReduceFwd::Execute(const Tensors& tensors) const {
   auto input_mem = tensors.data.GetDNNLData();
   if (tensors.out.shape().Size() == 1) {
     // scalar result
-    auto out_mem = dnnl::memory(reduce_pd->dst_desc(), engine, 
tensors.out.data().dptr<float>());
+    auto out_mem = dnnl::memory(reduce_pd->dst_desc(), engine, 
tensors.out.data().dptr_);
     stream->RegisterPrimArgs(*reduce_fwd, {{DNNL_ARG_SRC, *input_mem}, 
{DNNL_ARG_DST, out_mem}});
   } else {
     auto desc    = reduce_pd->dst_desc();
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h 
b/src/operator/numpy/np_elemwise_broadcast_op.h
index f27b9a7772..c2db724cb8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -233,8 +233,7 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& 
attrs,
   const TBlob& lhs = inputs[0];
   const TBlob& rhs = inputs[1];
   const TBlob& out = outputs[0];
-  if ((common::is_float(lhs.type_flag_) || common::is_bfloat(lhs.type_flag_)) 
&&
-      (common::is_float(rhs.type_flag_) || common::is_bfloat(rhs.type_flag_))) 
{
+  if ((common::is_float(lhs.type_flag_)) && 
(common::is_float(rhs.type_flag_))) {
     if (lhs.type_flag_ == out.type_flag_) {
       MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, 
rhs, out, req[0]);
     } else {
@@ -370,8 +369,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& 
attrs,
     MixedBinaryElemwiseCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, 
outputs);
   } else {
     mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-    if ((common::is_float(lhs.type_flag_) || 
common::is_bfloat(lhs.type_flag_)) &&
-        (common::is_float(rhs.type_flag_) || 
common::is_bfloat(rhs.type_flag_))) {
+    if ((common::is_float(lhs.type_flag_)) && 
(common::is_float(rhs.type_flag_))) {
       if (lhs.type_flag_ == out.type_flag_) {
         MixedAllRealBinaryBroadcastCompute<xpu, ROP>(
             attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, 
new_lshape, new_rshape);
diff --git a/src/operator/numpy/np_true_divide-inl.h 
b/src/operator/numpy/np_true_divide-inl.h
index 45fbd643e0..5d0700f56b 100644
--- a/src/operator/numpy/np_true_divide-inl.h
+++ b/src/operator/numpy/np_true_divide-inl.h
@@ -115,8 +115,7 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs& attrs,
     }
   } else {
     // Case when types of the 2 input tensors are different
-    if ((common::is_float(lhs.type_flag_) || 
common::is_bfloat(lhs.type_flag_)) &&
-        (common::is_float(rhs.type_flag_) || 
common::is_bfloat(rhs.type_flag_))) {
+    if ((common::is_float(lhs.type_flag_)) && 
(common::is_float(rhs.type_flag_))) {
       // both lhs and rhs are float types, output type is the more precise one
       TBlob temp_tblob;
       if (lhs.type_flag_ == out.type_flag_) {
@@ -239,8 +238,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& 
attrs,
           });
         }
       } else {
-        if ((common::is_float(lhs.type_flag_) || 
common::is_bfloat(lhs.type_flag_)) &&
-            (common::is_float(rhs.type_flag_) || 
common::is_bfloat(rhs.type_flag_))) {
+        if ((common::is_float(lhs.type_flag_)) && 
(common::is_float(rhs.type_flag_))) {
           // lhs and rhs have different float types, the output is the more 
precise one
           TBlob temp_tblob;
           if (lhs.type_flag_ == out.type_flag_) {
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index cfff87c3c6..c8656e7967 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -102,6 +102,7 @@ struct SampleUniformParam : public 
dmlc::Parameter<SampleUniformParam>,
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -122,6 +123,7 @@ struct SampleNormalParam : public 
dmlc::Parameter<SampleNormalParam>, NormalPara
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -144,6 +146,7 @@ struct SampleGammaParam : public 
dmlc::Parameter<SampleGammaParam>, GammaParam,
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -166,6 +169,7 @@ struct SampleExponentialParam : public 
dmlc::Parameter<SampleExponentialParam>,
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -188,6 +192,7 @@ struct SamplePoissonParam : public 
dmlc::Parameter<SamplePoissonParam>,
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -210,6 +215,7 @@ struct SampleBinomialParam : public 
dmlc::Parameter<SampleBinomialParam>,
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -232,6 +238,7 @@ struct SampleNegBinomialParam : public 
dmlc::Parameter<SampleNegBinomialParam>,
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -256,6 +263,7 @@ struct SampleGenNegBinomialParam : public 
dmlc::Parameter<SampleGenNegBinomialPa
         .add_enum("float32", mshadow::kFloat32)
         .add_enum("float64", mshadow::kFloat64)
         .add_enum("float16", mshadow::kFloat16)
+        .add_enum("bfloat16", mshadow::kBfloat16)
         .set_default(-1)
         .describe(
             "DType of the output in case this can't be inferred. "
@@ -396,7 +404,7 @@ static inline void uniform_op(const nnvm::NodeAttrs& attrs,
   Tensor<xpu, 1, float> low, high;
   GetSamplingTempData<xpu, float>(param.low, param.high, ctx, &low, &high);
   UniformSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(low, high, out, pgen, s);
@@ -414,7 +422,7 @@ static inline void normal_op(const nnvm::NodeAttrs& attrs,
   Tensor<xpu, 1, float> loc, scale;
   GetSamplingTempData<xpu, float>(param.loc, param.scale, ctx, &loc, &scale);
   NormalSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(loc, scale, out, pgen, s);
@@ -433,7 +441,7 @@ static inline void gamma_op(const nnvm::NodeAttrs& attrs,
   Tensor<xpu, 1, float> alpha, beta;
   GetSamplingTempData<xpu, float>(param.alpha, param.beta, ctx, &alpha, &beta);
   GammaSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(alpha, beta, out, pgen, s);
@@ -451,7 +459,7 @@ static inline void exponential_op(const nnvm::NodeAttrs& 
attrs,
   Tensor<xpu, 1, float> lam, dummy;
   GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
   ExponentialSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(lam, out, pgen, s);
@@ -469,7 +477,7 @@ static inline void poisson_op(const nnvm::NodeAttrs& attrs,
   Tensor<xpu, 1, float> lam, dummy;
   GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
   PoissonSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(lam, out, pgen, s);
@@ -488,7 +496,7 @@ static inline void binomial_op(const nnvm::NodeAttrs& attrs,
   Tensor<xpu, 1, float> n, p;
   GetSamplingTempData<xpu, float>(param.n, param.p, ctx, &n, &p);
   BinomialSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(n, p, out, pgen, s);
@@ -507,7 +515,7 @@ static inline void neg_binomial_op(const nnvm::NodeAttrs& 
attrs,
   Tensor<xpu, 1, float> k, p;
   GetSamplingTempData<xpu, float>(param.k, param.p, ctx, &k, &p);
   NegativeBinomialSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(k, p, out, pgen, s);
@@ -528,7 +536,7 @@ static inline void gen_neg_binomial_op(const 
nnvm::NodeAttrs& attrs,
   Tensor<xpu, 1, float> mu, alpha;
   GetSamplingTempData<xpu, float>(param.mu, param.alpha, ctx, &mu, &alpha);
   GeneralizedNegativeBinomialSampler<xpu> sampler;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
     RandGenerator<xpu, OType>* pgen = 
ctx.requested[0].get_parallel_random<xpu, OType>();
     Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
     sampler.Sample(mu, alpha, out, pgen, s);
@@ -799,11 +807,11 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs,
       dtype = mxnet::common::GetDefaultDtype();
     }
   }
-  bool dtype_ok =
-      (dtype == mshadow::kFloat16) || (dtype == mshadow::kFloat32) || (dtype 
== mshadow::kFloat64);
-  CHECK(dtype_ok) << "Output type must be float16, float32, float64: dtype is 
" << dtype_out
-                  << " vs " << mshadow::kFloat16 << " or " << 
mshadow::kFloat32 << " or "
-                  << mshadow::kFloat64;
+  bool dtype_ok = dtype == mshadow::kBfloat16 || dtype == mshadow::kFloat16 ||
+                  dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64;
+  CHECK(dtype_ok) << "Output type must be bfloat16, float16, float32, float64: 
dtype is "
+                  << dtype_out << " vs " << mshadow::kBfloat16 << " or " << 
mshadow::kFloat16
+                  << " or " << mshadow::kFloat32 << " or " << 
mshadow::kFloat64;
   TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
   return true;
 }
diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc 
b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
index 6905b118ba..df50a1abdb 100644
--- a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
@@ -60,7 +60,7 @@ bool DNNLBatchDotShape(const nnvm::NodeAttrs& attrs,
   }
 
   out_shapes->at(DotOut::out) = base_out_shapes[DotOut::out];
-  if (param.quantized && !param.enable_float_output) {
+  if (param.quantized && !param.enabled_float_output.has_value()) {
     SHAPE_ASSIGN_CHECK(*out_shapes, DotOut::out_min, mshadow::Shape1(1));
     SHAPE_ASSIGN_CHECK(*out_shapes, DotOut::out_max, mshadow::Shape1(1));
   }
@@ -74,6 +74,11 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
   const DNNLDotParam& param    = nnvm::get<DNNLDotParam>(attrs.parsed);
   const size_t base_num_inputs = 2;
   if (param.quantized) {
+    if (in_types->at(DotIn::lhs) == mshadow::kBfloat16 ||
+        in_types->at(DotIn::rhs) == mshadow::kBfloat16) {
+      return false;
+    }
+
     CHECK(in_types->at(DotIn::lhs) == mshadow::kInt8 || 
in_types->at(DotIn::lhs) == mshadow::kUint8)
         << "Quantized batch-dot lhs only supports int8/uint8 input, while "
         << in_types->at(DotIn::lhs) << " is given.";
@@ -85,8 +90,8 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
       TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
     }
 
-    if (param.enable_float_output) {
-      TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+    if (param.enabled_float_output.has_value()) {
+      TYPE_ASSIGN_CHECK(*out_types, DotOut::out, 
param.enabled_float_output.value());
     } else {
       if (param.min_calib_range.has_value() && 
param.max_calib_range.has_value()) {
         TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kInt8);
@@ -97,9 +102,22 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
       TYPE_ASSIGN_CHECK(*out_types, DotOut::out_max, mshadow::kFloat32);
     }
   } else {
-    TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kFloat32);
-    TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kFloat32);
-    TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+    if ((*in_types)[DotIn::lhs] == mshadow::kBfloat16 ||
+        (*in_types)[DotIn::rhs] == mshadow::kBfloat16) {
+      TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kBfloat16);
+      TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kBfloat16);
+      if (param.enabled_float_output.has_value()) {
+        CHECK_EQ(param.enabled_float_output.value(), mshadow::kFloat32);
+        TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kBfloat16);
+      }
+    } else {
+      CHECK(!param.enabled_float_output.has_value());
+      TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+    }
   }
 
   return true;
@@ -122,7 +140,7 @@ NNVM_REGISTER_OP(_sg_onednn_batch_dot)
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
-      return (param.quantized && !param.enable_float_output) ? 3 : 1;
+      return (param.quantized && !param.enabled_float_output.has_value()) ? 3 
: 1;
     })
     .set_attr_parser(ParamParser<DNNLDotParam>)
     .set_attr<nnvm::FListInputNames>(
@@ -140,7 +158,7 @@ NNVM_REGISTER_OP(_sg_onednn_batch_dot)
         "FListOutputNames",
         [](const NodeAttrs& attrs) {
           auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
-          if (param.quantized && !param.enable_float_output) {
+          if (param.quantized && !param.enabled_float_output.has_value()) {
             return std::vector<std::string>{"output", "min_output", 
"max_output"};
           } else {
             return std::vector<std::string>{"output"};
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc 
b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 5b2c9ad3e0..5d440d6d3f 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -253,7 +253,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
         post_requantize_         = true;
         weight_channelwise_scale = true;
       }
-      if (dnnl_param.enable_float_output) {
+      if (dnnl_param.enabled_float_output.has_value()) {
         weight_channelwise_scale = true;
       }
       data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, 
cached_data_max_);
@@ -270,7 +270,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
       if (dnnl_param.with_sum) {
         sum_in_scale = GetQuantizeScale(inputs[in_sum].dtype(), 
cached_sum_min_, cached_sum_max_);
       }
-      if (post_requantize_ || dnnl_param.enable_float_output) {
+      if (post_requantize_ || dnnl_param.enabled_float_output.has_value()) {
         if (post_requantize_) {
           output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? 
mshadow::kUint8 : mshadow::kInt8,
                                           cached_output_min_,
@@ -387,7 +387,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
     DNNLConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), 
new_inputs, req, {output});
   }
 
-  if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
+  if (dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value()) {
     *outputs[kMin].data().dptr<float>() = cached_output_min_;
     *outputs[kMax].data().dptr<float>() = cached_output_max_;
   }
@@ -470,7 +470,6 @@ static void SgDNNLConvParamParser(nnvm::NodeAttrs* attrs) {
       auto& post_act_param = (param_.full_conv_param.dnnl_param.with_act && 
!with_act) ?
                                  param_.full_conv_param.act_param :
                                  param_.full_conv_param.postsum_act_param;
-      with_act = true;
       if (node_name == "Activation") {
         const auto act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
         post_act_param.alg   = GetDNNLActAlgo(act_param);
@@ -483,6 +482,7 @@ static void SgDNNLConvParamParser(nnvm::NodeAttrs* attrs) {
         post_act_param.alg    = dnnl::algorithm::eltwise_bounded_relu;
         post_act_param.alpha  = clip_param.a_max;
       }
+      with_act = true;
     }
   });
   attrs->parsed = std::move(param_);
@@ -521,7 +521,7 @@ static std::vector<std::string> 
SgDNNLConvListInputNames(const NodeAttrs& attrs)
 static std::vector<std::string> SgDNNLConvListOutputNames(const NodeAttrs& 
attrs) {
   auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
   if (param.full_conv_param.dnnl_param.quantized &&
-      !param.full_conv_param.dnnl_param.enable_float_output) {
+      !param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
     return std::vector<std::string>{"output", "output_min", "output_max"};
   } else {
     return std::vector<std::string>{"output"};
@@ -582,7 +582,7 @@ static bool SgDNNLConvInferShape(const nnvm::NodeAttrs& 
attrs,
       }
     }
     out_shapes->at(0) = base_out_shapes[0];
-    if (!param.full_conv_param.dnnl_param.enable_float_output) {
+    if (!param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
       SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
       SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
     }
@@ -636,8 +636,9 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& 
attrs,
       }
     }
 
-    if (param.full_conv_param.dnnl_param.enable_float_output) {
-      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+    if (param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
+      TYPE_ASSIGN_CHECK(
+          *out_types, 0, 
param.full_conv_param.dnnl_param.enabled_float_output.value());
     } else {
       if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
           param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
@@ -656,8 +657,8 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& 
attrs,
     return result;
   } else {
     bool result = DefaultSubgraphOpType(attrs, in_types, out_types);
-    if (param.full_conv_param.dnnl_param.amp_out_dtype.has_value()) {
-      (*out_types)[0] = param.full_conv_param.dnnl_param.amp_out_dtype.value();
+    if (param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
+      (*out_types)[0] = 
param.full_conv_param.dnnl_param.enabled_float_output.value();
     }
     return result;
   }
@@ -690,7 +691,7 @@ static bool SgDNNLConvOpStorageType(const nnvm::NodeAttrs& 
attrs,
       }
     }
     out_stypes->at(0) = base_out_stypes[0];
-    if (!param.full_conv_param.dnnl_param.enable_float_output) {
+    if (!param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
       type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
       type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
     }
@@ -754,10 +755,11 @@ NNVM_REGISTER_OP(_sg_onednn_conv)
     .set_num_inputs(SgDNNLConvNumInputs)
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
-      return param.full_conv_param.dnnl_param.quantized &&
-                     !param.full_conv_param.dnnl_param.enable_float_output ?
-                 3 :
-                 1;
+      if (param.full_conv_param.dnnl_param.quantized &&
+          !param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
+        return 3;
+      }
+      return 1;
     })
     .set_attr_parser(SgDNNLConvParamParser)
     .set_attr<nnvm::FListInputNames>("FListInputNames", 
SgDNNLConvListInputNames)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc 
b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 22971bf487..ae21195706 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -116,7 +116,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   const auto& dnnl_param    = full_param_.dnnl_param;
   const bool has_bias       = !default_param.no_bias;
   const bool quantized      = dnnl_param.quantized;
-  const bool out_quantized  = dnnl_param.quantized && 
!dnnl_param.enable_float_output;
+  const bool out_quantized  = dnnl_param.quantized && 
!dnnl_param.enabled_float_output.has_value();
   const bool channel_wise   = quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
                             dnnl_param.channel_wise_quantize.value();
 
@@ -253,7 +253,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   DNNLStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
   DNNLStream::Get()->Submit();
 
-  if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
+  if (dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value()) {
     float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
     float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
 
@@ -395,7 +395,7 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
     support_channelwise_scale = true;
     fuse_requantize           = true;
   }
-  if (dnnl_param.enable_float_output) {
+  if (dnnl_param.enabled_float_output.has_value()) {
     support_channelwise_scale = true;
   }
   // channel_wise  support_channelwise_scale  result
@@ -461,7 +461,7 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
 
   size_t num_channel = cached_weight_.shape()[0];
   float out_scale    = 1.0f;
-  if (fuse_requantize || dnnl_param.enable_float_output) {
+  if (fuse_requantize || dnnl_param.enabled_float_output.has_value()) {
     float tmp_scale_ = 1.0f;
     if (fuse_requantize) {
       if (dnnl_param.with_eltwise) {
@@ -513,7 +513,7 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
     out_scale = data_scale_ * weight_scales_[0];
   }
 
-  if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+  if (dnnl_param.with_sum && !dnnl_param.enabled_float_output.has_value()) {
     float sum_in_scale =
         GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, 
cached_sum_max_);
     full_param_.sum_scale = out_scale / sum_in_scale;
@@ -657,7 +657,7 @@ static std::vector<std::string> 
SgDNNLFCListInputNames(const NodeAttrs& attrs) {
         input_names.emplace_back("bias_max");
       }
     }
-    if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+    if (dnnl_param.with_sum && !dnnl_param.enabled_float_output.has_value()) {
       input_names.emplace_back("sum_min");
       input_names.emplace_back("sum_max");
     }
@@ -668,7 +668,7 @@ static std::vector<std::string> 
SgDNNLFCListInputNames(const NodeAttrs& attrs) {
 static std::vector<std::string> SgDNNLFCListOutputNames(const NodeAttrs& 
attrs) {
   auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
   if (full_param.dnnl_param.quantized) {
-    if (full_param.dnnl_param.enable_float_output)
+    if (full_param.dnnl_param.enabled_float_output.has_value())
       return std::vector<std::string>{"output"};
     else
       return std::vector<std::string>{"output", "output_min", "output_max"};
@@ -709,7 +709,7 @@ static bool SgDNNLFCInferShape(const nnvm::NodeAttrs& attrs,
     }
 
     out_shapes->at(0) = base_out_shapes[0];
-    if (!full_param.dnnl_param.enable_float_output) {
+    if (!full_param.dnnl_param.enabled_float_output.has_value()) {
       SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
       SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
     }
@@ -754,8 +754,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
       }
     }
     if (idx.IsSumExist()) {
-      if (full_param.dnnl_param.enable_float_output) {
-        TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+      if (full_param.dnnl_param.enabled_float_output.has_value()) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.sum, 
full_param.dnnl_param.enabled_float_output.value());
       } else {
         CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) 
== mshadow::kUint8)
             << "QuantizedFullyConnected sum input only supports int8/uint8, 
while "
@@ -766,8 +766,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
       TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
     }
 
-    if (full_param.dnnl_param.enable_float_output) {
-      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+    if (full_param.dnnl_param.enabled_float_output.has_value()) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, 
full_param.dnnl_param.enabled_float_output.value());
     } else {
       if (full_param.dnnl_param.min_calib_range.has_value() &&
           full_param.dnnl_param.max_calib_range.has_value()) {
@@ -786,8 +786,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
     return true;
   } else {
     bool result = DefaultSubgraphOpType(attrs, in_types, out_types);
-    if (full_param.dnnl_param.amp_out_dtype.has_value()) {
-      (*out_types)[0] = full_param.dnnl_param.amp_out_dtype.value();
+    if (full_param.dnnl_param.enabled_float_output.has_value()) {
+      (*out_types)[0] = full_param.dnnl_param.enabled_float_output.value();
     }
     return result;
   }
@@ -814,7 +814,7 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& 
attrs,
     }
 
     out_attrs->at(0) = base_out_attrs[0];
-    if (!full_param.dnnl_param.enable_float_output) {
+    if (!full_param.dnnl_param.enabled_float_output.has_value()) {
       type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
       type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
     }
@@ -890,8 +890,11 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected)
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
-      return (full_param.dnnl_param.quantized && 
!full_param.dnnl_param.enable_float_output) ? 3 :
-                                                                               
                1;
+      if (full_param.dnnl_param.quantized &&
+          !full_param.dnnl_param.enabled_float_output.has_value()) {
+        return 3;
+      }
+      return 1;
     })
     .set_attr_parser(SgDNNLFCParamParser)
     .set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLFCListInputNames)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h 
b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
index 2c19b7b68d..34d3c7a86b 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
@@ -90,7 +90,7 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
     if (EndsWith(output_n->op()->name, "elemwise_add")) {
       if (quantized_) {
         auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
-        if (!fc_param.dnnl_param.enable_float_output) {
+        if (!fc_param.dnnl_param.enabled_float_output.has_value()) {
           // For quantized graph, when FC floating point output is not enabled, 
elementwise add must
           // also be quantized (min and max value have to be already stored in 
elementwise add).
           CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
@@ -234,8 +234,10 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
           // sum_tensor.data    -->   fc_out.max
           // sum_tensor.min     -->   sum_tensor.min
           // sum_tensor.max     -->   sum_tensor.max
-          const int not_rotated_end =
-              (fc_param.dnnl_param.quantized && 
!fc_param.dnnl_param.enable_float_output) ? 2 : 0;
+          const int not_rotated_end = (fc_param.dnnl_param.quantized &&
+                                       
!fc_param.dnnl_param.enabled_float_output.has_value()) ?
+                                          2 :
+                                          0;
 
           std::rotate(input_entries->begin() + base_inputs - 1,
                       input_entries->end() - 1 - not_rotated_end,
diff --git a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h 
b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
index 55eec1cab8..6ec7c54e38 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
@@ -128,7 +128,7 @@ class SgDNNLPostAMPProperty : public SubgraphProperty {
     });
     CHECK_NOTNULL(fuse_node);
     CHECK_NOTNULL(amp_node);
-    fuse_node->attrs.dict["amp_out_dtype"] = amp_node->attrs.dict["dtype"];
+    fuse_node->attrs.dict["enabled_float_output"] = 
amp_node->attrs.dict["dtype"];
     fuse_node->op()->attr_parser(&(fuse_node->attrs));
     return fuse_node;
   }
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h 
b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 14717592b0..94c7e63085 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -217,7 +217,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
     // When only fused quantized operator and requantize, set 
min/max_calib_range,
     // When fused quantized operator + requantize + dequantize, set dequantize 
flag to true.
     if (dequantize_node != nullptr) {
-      fuse_node->attrs.dict["enable_float_output"] = "True";
+      fuse_node->attrs.dict["enabled_float_output"] = 
type_string(mshadow::kFloat32);
     } else {
       fuse_node->attrs.dict["min_calib_range"] =
           std::to_string(requantize_param.min_calib_range.value());
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer-inl.h 
b/src/operator/subgraph/dnnl/dnnl_transformer-inl.h
index 542c15ad36..e4ece57d48 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer-inl.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer-inl.h
@@ -29,18 +29,14 @@ namespace op {
 struct DNNLSelfAttParam : public dmlc::Parameter<DNNLSelfAttParam> {
   int heads;
   bool quantized;
-  bool enable_float_output;
-  dmlc::optional<float> min_calib_range;  // min float value calculated from 
calibration dataset
-  dmlc::optional<float> max_calib_range;  // max float value calculated from 
calibration dataset
-  dmlc::optional<int> amp_out_dtype;      // mshadow dtype of a fused amp_cast 
node
+  dmlc::optional<float> min_calib_range;     // min float value calculated 
from calibration dataset
+  dmlc::optional<float> max_calib_range;     // max float value calculated 
from calibration dataset
+  dmlc::optional<int> enabled_float_output;  // mshadow dtype of a fused 
amp_cast node
 
   DMLC_DECLARE_PARAMETER(DNNLSelfAttParam) {
     DMLC_DECLARE_FIELD(heads).describe("Set number of heads.");
     DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
         "Whether it's a quantized self attention matmul operator.");
-    DMLC_DECLARE_FIELD(enable_float_output)
-        .set_default(false)
-        .describe("Whether to enable float32 output.");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
         .describe(
@@ -53,9 +49,7 @@ struct DNNLSelfAttParam : public 
dmlc::Parameter<DNNLSelfAttParam> {
             "The maximum scalar value in the form of float32 obtained "
             "through calibration. If present, it will be used to by "
             "quantized self-attention op to calculate primitive scale.");
-    DMLC_DECLARE_FIELD(amp_out_dtype)
-        .set_default(dmlc::optional<int>())
-            MXNET_ADD_ALL_TYPES.describe("The output type deduced from the 
fused amp_cast.");
+    DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
   }
 };
 
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer.cc 
b/src/operator/subgraph/dnnl/dnnl_transformer.cc
index baeb7b4448..67fee5f98f 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer.cc
+++ b/src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -56,7 +56,7 @@ static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
     out_shape->resize(3);
     SHAPE_ASSIGN_CHECK(
         *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, 
qkv_shape[1], qkv_shape[1]}));
-    if (!params.enable_float_output) {
+    if (!params.enabled_float_output.has_value()) {
       SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));  // min output
       SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));  // max output
     }
@@ -89,9 +89,9 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& 
attrs,
     TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
     TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
 
-    if (params.enable_float_output) {
+    if (params.enabled_float_output.has_value()) {
       CHECK_EQ(out_types->size(), 1U);
-      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
     } else {
       CHECK_EQ(out_types->size(), 3U);
       if (params.min_calib_range.has_value() && 
params.max_calib_range.has_value()) {
@@ -109,8 +109,8 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& 
attrs,
       TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kFloat32);
       TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
     } else if (in_types->at(0) == mshadow::kBfloat16) {
-      if (params.amp_out_dtype.has_value()) {
-        TYPE_ASSIGN_CHECK(*out_types, 0, params.amp_out_dtype.value());
+      if (params.enabled_float_output.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
       } else {
         TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kBfloat16);
       }
@@ -239,7 +239,7 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
       max_output_ = param_.max_calib_range.value();
       oscale      = GetQuantizeScale(out_tensor.dtype(), min_output_, 
max_output_) /
                (data_scale_ * data_scale_);
-    } else if (param_.enable_float_output) {
+    } else if (param_.enabled_float_output.has_value()) {
       oscale = 1.0f / (data_scale_ * data_scale_);
     } else {
       mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
@@ -295,7 +295,7 @@ void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
   DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
   DNNLStream::Get()->Submit();
 
-  if (param_.quantized && !param_.enable_float_output) {
+  if (param_.quantized && !param_.enabled_float_output.has_value()) {
     float* output_min = outputs[1].data().dptr<float>();
     float* output_max = outputs[2].data().dptr<float>();
 
@@ -331,7 +331,7 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-      if (param.quantized && !param.enable_float_output) {
+      if (param.quantized && !param.enabled_float_output.has_value()) {
         return 3;
       } else {
         return 1;
@@ -354,7 +354,8 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
                                         auto const& param =
                                             
nnvm::get<DNNLSelfAttParam>(attrs.parsed);
                                         std::vector<std::string> 
output_names{"output"};
-                                        if (param.quantized && 
!param.enable_float_output) {
+                                        if (param.quantized &&
+                                            
!param.enabled_float_output.has_value()) {
                                           
output_names.emplace_back("min_output");
                                           
output_names.emplace_back("max_output");
                                         }
@@ -407,7 +408,7 @@ static bool SgDNNLSelfAttValShape(const NodeAttrs& attrs,
         0,
         mxnet::TShape(
             {att_shape[0], att_shape[2], att_shape[1] * qkv_shape[2] / 
params.heads / QKV_NUM}));
-    if (!params.enable_float_output) {
+    if (!params.enabled_float_output.has_value()) {
       SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));  // min output
       SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));  // max output
     }
@@ -455,9 +456,9 @@ static bool SgDNNLSelfAttValInferType(const 
nnvm::NodeAttrs& attrs,
       TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
     }
 
-    if (params.enable_float_output) {
+    if (params.enabled_float_output.has_value()) {
       CHECK_EQ(out_types->size(), 1U);
-      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
     } else {
       CHECK_EQ(out_types->size(), 3U);
       if (params.min_calib_range.has_value() && 
params.max_calib_range.has_value()) {
@@ -478,8 +479,9 @@ static bool SgDNNLSelfAttValInferType(const 
nnvm::NodeAttrs& attrs,
     } else if (in_types->at(0) == mshadow::kBfloat16 || in_types->at(1) == 
mshadow::kBfloat16) {
       TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kBfloat16);
       TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kBfloat16);
-      if (params.amp_out_dtype.has_value()) {
-        TYPE_ASSIGN_CHECK(*out_types, 0, params.amp_out_dtype.value());
+      if (params.enabled_float_output.has_value()) {
+        CHECK_EQ(params.enabled_float_output.value(), mshadow::kFloat32);
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
       } else {
         TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kBfloat16);
       }
@@ -633,7 +635,7 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
       max_output_ = param_.max_calib_range.value();
       oscale      = GetQuantizeScale(out_tensor.dtype(), min_output_, 
max_output_) /
                (att_scale_ * qkv_scale_);
-    } else if (param_.enable_float_output) {
+    } else if (param_.enabled_float_output.has_value()) {
       oscale = 1.0f / (att_scale_ * qkv_scale_);
     } else {
       mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
@@ -720,7 +722,7 @@ void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
   DNNLStream::Get()->RegisterPrimArgs(*reorder_, reorder_args);
   DNNLStream::Get()->Submit();
 
-  if (param_.quantized && !param_.enable_float_output) {
+  if (param_.quantized && !param_.enabled_float_output.has_value()) {
     float* output_min = outputs[1].data().dptr<float>();
     float* output_max = outputs[2].data().dptr<float>();
 
@@ -742,7 +744,7 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_valatt)
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
-      if (param.quantized && !param.enable_float_output) {
+      if (param.quantized && !param.enabled_float_output.has_value()) {
         return 3;
       } else {
         return 1;
@@ -768,7 +770,8 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_valatt)
                                         auto const& param =
                                             
nnvm::get<DNNLSelfAttParam>(attrs.parsed);
                                         std::vector<std::string> 
output_names{"output"};
-                                        if (param.quantized && 
!param.enable_float_output) {
+                                        if (param.quantized &&
+                                            
!param.enabled_float_output.has_value()) {
                                           
output_names.emplace_back("min_output");
                                           
output_names.emplace_back("max_output");
                                         }
diff --git a/src/operator/tensor/elemwise_unary_op.h 
b/src/operator/tensor/elemwise_unary_op.h
index b675acf9f9..83ad711fa1 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -280,7 +280,7 @@ class UnaryOp : public OpBase {
     if (mxnet::common::is_float(inputs[0].type_flag_)) {
       UnaryOp::Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
     } else {
-      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, DType, _, {
         MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
           MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
             if (inputs[0].Size() != 0) {
@@ -622,7 +622,7 @@ void HardSigmoidForward(const nnvm::NodeAttrs& attrs,
   const TBlob& out_data         = outputs[0];
   const HardSigmoidParam& param = nnvm::get<HardSigmoidParam>(attrs.parsed);
   using namespace mxnet_op;
-  MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(out_data.type_flag_, DType, _, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
       Kernel<hard_sigmoid_forward<req_type>, xpu>::Launch(s,
                                                           out_data.Size(),
@@ -650,7 +650,7 @@ void HardSigmoidBackward(const nnvm::NodeAttrs& attrs,
   const TBlob& in_grad          = outputs[0];
   const HardSigmoidParam& param = nnvm::get<HardSigmoidParam>(attrs.parsed);
   using namespace mxnet_op;
-  MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(in_data.type_flag_, DType, _, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
       Kernel<hard_sigmoid_backward<req_type>, xpu>::Launch(s,
                                                            in_grad.Size(),
@@ -863,7 +863,7 @@ void NumpyNanToNumOpForward(const nnvm::NodeAttrs& attrs,
     return;
   }
 
-  MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH_EX(out_data.type_flag_, DType, _, {
     DType defaultnan = static_cast<DType>(param.nan);
     DType posinf;
     DType neginf;
diff --git a/tests/python/dnnl/op_cfg.py b/tests/python/dnnl/op_cfg.py
new file mode 100644
index 0000000000..9effb305b1
--- /dev/null
+++ b/tests/python/dnnl/op_cfg.py
@@ -0,0 +1,410 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import namedtuple
+from itertools import product
+from functools import partial
+
+import mxnet as mx
+from mxnet.base import (_is_np_op, _NP_OP_PREFIX, _NP_EXT_OP_PREFIX, 
_NP_INTERNAL_OP_PREFIX,
+                        _OP_NAME_PREFIX_LIST)
+
+PREFIX_TO_MODULE = {
+    _NP_OP_PREFIX: mx.sym.np,
+    _NP_INTERNAL_OP_PREFIX: mx.sym.np._internal,
+    _NP_EXT_OP_PREFIX: mx.sym.npx
+}
+for nd_prefix in _OP_NAME_PREFIX_LIST:
+    module_name = nd_prefix[1:-1]  # nd_prefix == '_<module_name>_'
+    PREFIX_TO_MODULE[nd_prefix] = getattr(mx.sym, module_name)
+
+CFG_BASED_ON = '__based_on__'
+CFG_SUBGRAPH = '__subgraph__'
+CFG_RTOL_ATOL = '__rtol_atol__'
+DEFAULT_SHAPE = (8,)
+
+TensorArg = namedtuple('TensorArg', ['gen_tensor'])
+CfgBasedArg = namedtuple('CfgBasedArg', ['gen_arg'])
+SubgraphCfg = namedtuple('SubgraphCfg', ['base_op', 'backend'])
+
+
+def get_op_sym_fn(op_name: str):
+    for prefix, module in PREFIX_TO_MODULE.items():
+        if op_name.startswith(prefix):
+            return getattr(module, op_name[len(prefix):])
+    try:
+        return getattr(mx.sym, op_name)
+    except AttributeError:
+        try:
+            # op with '_' prefix
+            return getattr(mx.sym, op_name[1:])
+        except AttributeError:
+            return getattr(mx.sym._internal, op_name)
+
+
+def default_tensor(dim_or_shape, dtype):
+    if isinstance(dim_or_shape, (tuple, list)):
+        shape = dim_or_shape
+        return mx.nd.random.normal(0, 1, shape, dtype)
+    dim = dim_or_shape
+    return mx.nd.random.normal(0, 1, DEFAULT_SHAPE*dim, dtype)
+
+
+def common_weight_tensor(shape, dtype, numpy=False):
+    tensor = mx.nd.random.normal(0, 0.1, shape, dtype)
+    return tensor.as_np_ndarray() if numpy else tensor
+
+
+def valatt_attention_tensor(cfg):
+    qkv = cfg['queries_keys_values']
+    batch, seq_len, _ = qkv.shape
+    heads = cfg['heads']
+    att_shape = (batch, heads, seq_len, seq_len)
+    return mx.nd.random.randint(0, 2, att_shape).astype(qkv.dtype)
+
+
+def get_all_ops_cfgs(dtype):
+    return {
+        'Convolution': {
+            'data,kernel': [
+                (default_tensor(3, dtype), (3,)),
+                (default_tensor(4, dtype), (3, 3)),
+                (default_tensor(5, dtype), (3, 3, 3))
+            ],
+            'weight': [TensorArg(common_weight_tensor)],
+            'bias': [TensorArg(common_weight_tensor)],
+            'no_bias': [False],
+            'num_filter': [8]
+        },
+        'Deconvolution': {CFG_BASED_ON: 'Convolution'},
+        'FullyConnected': {
+            'data': [default_tensor(2, dtype)],
+            'weight': [TensorArg(common_weight_tensor)],
+            'bias': [TensorArg(common_weight_tensor)],
+            'num_hidden': [3]
+        },
+        'Pooling': {
+            'data,kernel': [
+                (default_tensor(3, dtype), (3,)),
+                (default_tensor(4, dtype), (3, 3)),
+                (default_tensor(5, dtype), (3, 3, 3))
+            ],
+        },
+        '_contrib_AdaptiveAvgPooling2D': {
+            'data,kernel,output_size': [(default_tensor(4, dtype), (2, 2), (4, 
4))],
+        },
+
+        ######################################### Casting 
#########################################
+
+        'Cast': {
+            'data': [default_tensor(2, dtype)],
+            'dtype': ['bool']
+        },
+        '_contrib_quantize_v2': {
+            'data': [default_tensor(2, dtype)],
+            'min_calib_range': [CfgBasedArg(lambda cfg: 
cfg['data'].min().asscalar())],
+            'max_calib_range': [CfgBasedArg(lambda cfg: 
cfg['data'].max().asscalar())],
+            CFG_RTOL_ATOL: [(0, 0)]
+        },
+
+        ##################################### No calculations 
#####################################
+
+        'Flatten': {
+            'data': [default_tensor(2, dtype)]
+        },
+        'Concat': {
+            '0,1,dim': [
+                (default_tensor(2, dtype), default_tensor(2, dtype), 0)
+            ]
+        },
+        'Reshape': {
+            '0': [default_tensor(2, dtype)],
+            '1': [(-1,)]
+        },
+        'transpose': {
+            '0': [default_tensor(2, dtype)]
+        },
+        'expand_dims': {
+            'data': [default_tensor(2, dtype)],
+            'axis': [-1]
+        },
+        'where': {
+            'x,y,condition': [
+                (default_tensor(2, dtype),
+                 default_tensor(2, dtype),
+                 mx.nd.random.randint(0, 2, DEFAULT_SHAPE*2, 'int32'))
+            ],
+        },
+        'take': {
+            'a,indices': [
+                (default_tensor(2, dtype),
+                 mx.nd.random.randint(0, DEFAULT_SHAPE[0], (2,), 'int32'))
+            ],
+            'axis': [-1]
+        },
+        'stack': {
+            '0,1': [
+                (default_tensor(2, dtype), default_tensor(2, dtype))
+            ]
+        },
+        '_split_v2': {
+            'ary,indices_or_sections': [
+                (default_tensor(2, dtype), (2, 3))
+            ],
+        },
+        'slice': {
+            'data,begin,end': [
+                (default_tensor(2, dtype), (0, 1), (2, 4))
+            ],
+        },
+        'space_to_depth': {
+            'data,block_size': [
+                (default_tensor(4, dtype), 2)
+            ],
+        },
+        '_copy': {
+            'data': [default_tensor(2, dtype)]
+        },
+        '_npi_transpose': {CFG_BASED_ON: 'transpose'},
+        '_npi_where': {CFG_BASED_ON: 'where'},
+        '_npx_reshape': {CFG_BASED_ON: 'Reshape'},
+
+        ###################################### Normalization 
######################################
+
+        'LayerNorm': {
+            'data': [default_tensor(2, dtype)],
+            'gamma': [TensorArg(common_weight_tensor)],
+            'beta': [TensorArg(common_weight_tensor)],
+        },
+        'BatchNorm': {
+            CFG_BASED_ON: 'LayerNorm',
+            'moving_mean': [TensorArg(common_weight_tensor)],
+            'moving_var': [
+                TensorArg(lambda shape, dtype: mx.nd.random.uniform(0, 1, 
shape, dtype))
+            ],
+        },
+        '_contrib_BatchNormWithReLU': {CFG_BASED_ON: 'BatchNorm'},
+        'LRN': {
+            'data,nsize': [(default_tensor(2, dtype), 3)]
+        },
+
+        ######################################## Reduction 
########################################
+
+        'mean': {
+            '0': [default_tensor(2, dtype)],
+            'axis': [0]
+        },
+        'sum': {CFG_BASED_ON: 'mean'},
+        '_npi_mean': {CFG_BASED_ON: 'mean'},
+        '_npi_sum': {CFG_BASED_ON: 'mean'},
+
+        ######################################### Softmax 
#########################################
+
+        'softmax': {
+            'data': [
+                default_tensor(2, dtype),
+                default_tensor(4, dtype)
+            ],
+            'axis': [-1]
+        },
+        'log_softmax': {CFG_BASED_ON: 'softmax'},
+        'masked_softmax': {
+            CFG_BASED_ON: 'softmax',
+            'mask': [
+                CfgBasedArg(
+                    lambda cfg: mx.nd.random.randint(0, 2, 
cfg['data'].shape).astype('bool')
+                )
+            ],
+        },
+
+        ################################### Activation / Unary 
####################################
+
+        'Activation': {
+            'data': [default_tensor(2, dtype)],
+            'act_type': ['sigmoid', 'log_sigmoid', 'relu', 'softrelu', 'tanh', 
'mish']
+        },
+        'LeakyReLU': {
+            'data': [default_tensor(2, dtype)],
+            'act_type': ['leaky', 'elu', 'gelu']
+        },
+        '_npi_exp': {
+            '0': [default_tensor(2, dtype)]
+        },
+        '_npi_sqrt': {
+            '0': [mx.nd.random.uniform(0, 8, DEFAULT_SHAPE*2, dtype)]
+        },
+        '_npi_square': {CFG_BASED_ON: '_npi_exp'},
+        '_npi_tanh': {CFG_BASED_ON: '_npi_exp'},
+
+        ######################################### Binary 
##########################################
+
+        'dot': {
+            '0,1': [
+                (default_tensor(3, dtype), default_tensor(3, dtype))
+            ],
+        },
+        'batch_dot': {CFG_BASED_ON: 'dot'},
+        'broadcast_add': {CFG_BASED_ON: 'dot'},
+        'broadcast_div': {CFG_BASED_ON: 'dot'},
+        'broadcast_mul': {CFG_BASED_ON: 'dot'},
+        'broadcast_sub': {CFG_BASED_ON: 'dot'},
+        'elemwise_add': {CFG_BASED_ON: 'dot'},
+        '_npi_dot': {CFG_BASED_ON: 'dot'},
+        '_npi_add': {CFG_BASED_ON: 'dot'},
+        '_npi_multiply': {CFG_BASED_ON: 'dot'},
+        '_npi_subtract': {CFG_BASED_ON: 'dot'},
+        '_npi_true_divide': {CFG_BASED_ON: 'dot'},
+
+        'add_n': {CFG_BASED_ON: 'dot'},  # this is not binary, but can work as 
binary
+
+        ######################################## Subgraph 
#########################################
+
+        '_sg_onednn_conv': {
+            CFG_BASED_ON: 'Convolution',
+            CFG_SUBGRAPH: [SubgraphCfg('Convolution', 'ONEDNN')],
+            'data,kernel': [
+                (default_tensor(4, dtype), (3, 3)),
+                (default_tensor(5, dtype), (3, 3, 3))
+            ]
+        },
+        '_sg_onednn_fully_connected': {
+            CFG_BASED_ON: 'FullyConnected',
+            CFG_SUBGRAPH: [SubgraphCfg('FullyConnected', 'ONEDNN')],
+        },
+        '_sg_onednn_batch_dot': {
+            CFG_BASED_ON: 'batch_dot',
+            CFG_SUBGRAPH: [SubgraphCfg('batch_dot', 'ONEDNN')],
+        },
+        '_sg_onednn_selfatt_qk': {
+            CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk', 'ONEDNN')],
+            'queries_keys_values': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), 
dtype)],
+            'heads': [2]
+        },
+        '_sg_onednn_selfatt_valatt': {
+            CFG_BASED_ON: '_sg_onednn_selfatt_qk',
+            CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_valatt', 'ONEDNN')],
+            'attention': [CfgBasedArg(valatt_attention_tensor)]
+        }
+    }
+
+
+def product_dict(dict_of_lists):
+    keys = dict_of_lists.keys()
+    lists = dict_of_lists.values()
+    for scenario in product(*lists):
+        yield dict(zip(keys, scenario))
+
+
+def resolve_cfg_references(args_cfg, all_ops_cfgs):
+    if len(args_cfg) == 0:
+        return {}
+    args_cfg = args_cfg.copy()
+    base_op = args_cfg.pop(CFG_BASED_ON, None)
+    base_cfg = all_ops_cfgs.get(base_op, {})
+    result_cfg = resolve_cfg_references(base_cfg, all_ops_cfgs)
+    result_cfg.update(args_cfg)
+    return result_cfg
+
+
+def get_op_cfg_generator(op_names, dtype):
+    all_ops_cfgs = get_all_ops_cfgs(dtype)
+    for op_name in set(op_names):
+        args_cfgs = all_ops_cfgs[op_name]
+        args_cfgs = resolve_cfg_references(args_cfgs, all_ops_cfgs)
+        for args_scenario in product_dict(args_cfgs):
+            yield (op_name, args_scenario)
+
+
+def get_symblock_from_args_scenario(op_name, args_scenario):
+    args_scenario = args_scenario.copy()
+    args_scenario.pop(CFG_RTOL_ATOL, None)  # not used here
+    subgraph_cfg = args_scenario.pop(CFG_SUBGRAPH, None)
+    if subgraph_cfg is None:
+        op_sym_fn = get_op_sym_fn(op_name)
+    else:
+        op_sym_fn = get_op_sym_fn(subgraph_cfg.base_op)
+
+
+    # split bound args (a single 'a,b,...' key that carries several arguments)
+    binded_args = [(k, v) for k, v in args_scenario.items() if ',' in k]
+    for arg_names, arg_cfgs in binded_args:
+        args_scenario.pop(arg_names)
+        arg_names = arg_names.replace(' ', '').split(',')
+        assert isinstance(arg_cfgs, tuple) and len(arg_cfgs) == len(arg_names)
+        for arg_name, arg_cfg in zip(arg_names, arg_cfgs):
+            assert arg_name not in args_scenario
+            args_scenario[arg_name] = arg_cfg
+
+    # generate cfg based args
+    for arg_name, arg_cfg in args_scenario.items():
+        if isinstance(arg_cfg, CfgBasedArg):
+            args_scenario[arg_name] = arg_cfg.gen_arg(args_scenario)
+
+    kw_args = {}
+    pos_args = {}
+    for arg_name, arg_cfg in args_scenario.items():
+        if isinstance(arg_cfg, (TensorArg, mx.nd.NDArray, mx.np.ndarray)):
+            arg_cfg = mx.sym.var(arg_name)
+            if _is_np_op(op_name):
+                arg_cfg = arg_cfg.as_np_ndarray()
+        if arg_name.isdigit():
+            pos_args[int(arg_name)] = arg_cfg
+        else:
+            kw_args[arg_name] = arg_cfg
+    pos_args = [pos_args[k] for k in sorted(pos_args.keys())]
+
+    sym = op_sym_fn(*pos_args, **kw_args)
+    if subgraph_cfg is not None:
+        if len(sym.list_outputs()) > 1:
+            sym = sym[0]
+        # add additional op (+1), so the graph pass can convert the tested op
+        sym = mx.sym.relu(sym).optimize_for(subgraph_cfg.backend)
+    assert op_name in sym.tojson()
+
+    args_with_shape, args_with_dtype = {}, {}
+    for arg_name, arg_cfg in args_scenario.items():
+        if isinstance(arg_cfg, (mx.nd.NDArray, mx.np.ndarray)):
+            args_with_shape[arg_name] = arg_cfg.shape
+            args_with_dtype[arg_name] = arg_cfg.dtype
+
+    infered_shapes_args, _, infered_shapes_auxs = 
sym.infer_shape(**args_with_shape)
+    infered_shapes_args = dict(zip(sym.list_arguments(), infered_shapes_args))
+    infered_shapes_auxs = dict(zip(sym.list_auxiliary_states(), 
infered_shapes_auxs))
+
+    infered_dtypes_args, _, infered_dtypes_auxs = 
sym.infer_type(**args_with_dtype)
+    infered_dtypes_args = dict(zip(sym.list_arguments(), infered_dtypes_args))
+    infered_dtypes_auxs = dict(zip(sym.list_auxiliary_states(), 
infered_dtypes_auxs))
+
+    symblock_input_data = {}
+    for arg_name in [*sym.list_arguments(), *sym.list_auxiliary_states()]:
+        tensor_cfg = args_scenario[arg_name]
+        if isinstance(tensor_cfg, TensorArg):
+            shape = infered_shapes_args.get(arg_name, 
infered_shapes_auxs.get(arg_name, None))
+            dtype = infered_dtypes_args.get(arg_name, 
infered_dtypes_auxs.get(arg_name, None))
+            tensor = tensor_cfg.gen_tensor(shape, dtype)
+        else:
+            tensor = tensor_cfg
+        symblock_input_data[arg_name] = tensor
+
+    symblock_input_syms = [mx.sym.var(name) for name in 
symblock_input_data.keys()]
+    if _is_np_op(op_name):
+        symblock_input_syms = [var.as_np_ndarray() for var in 
symblock_input_syms]
+    symblock = mx.gluon.SymbolBlock(sym, symblock_input_syms)
+    symblock.initialize()
+    assert len(symblock.collect_params()) == 0
+
+    return symblock, list(symblock_input_data.values())
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py 
b/tests/python/dnnl/subgraphs/subgraph_common.py
index a23ba3b69c..f58b025dc2 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -90,8 +90,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
     if k.find('_quantize') != -1:
       assert v['out_type'] == out_type
     if k.find(quantized_op_name) != -1:
-      if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
-          or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enable_float_output' in v:
+      if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected") 
+          or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enabled_float_output' in v:
         continue
       assert 'min_calib_range' in v
       assert 'max_calib_range' in v
diff --git a/tests/python/dnnl/subgraphs/test_amp_subgraph.py b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
index b66ea44cde..51460c8cc0 100644
--- a/tests/python/dnnl/subgraphs/test_amp_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
@@ -35,33 +35,16 @@ AMP_SG_PASS_NAME = 'ONEDNN_AMP'
 AMP_DTYPE = 'bfloat16'
 
 
-# Checks if amp (after the AMP_SG_PASS_NAME fuse) changes the name of tensors for calibration
-def check_amp_with_quantization(net, data_example, quantized_nodes):
-  net.optimize_for(data_example, backend=QUANTIZE_SG_PASS_NAME)
-  symnet = net.export(None)[0]
-  nodes = {n['name'] for n in json.loads(symnet.tojson())['nodes'] if n['op'] != 'null'}
-  quant_excluded_nodes = list(nodes - set(quantized_nodes))
-
-  _, calib_tensors1 = mx.contrib.quantization._quantize_symbol(
-      symnet, mx.current_context(), excluded_symbols=quant_excluded_nodes)
-
-  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
-                                    excluded_sym_names=quantized_nodes, cast_params_offline=True,
-                                    device=mx.current_context())
-  lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
-  lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
-  _, calib_tensors2 = mx.contrib.quantization._quantize_symbol(
-      lp_symnet, mx.cpu(), excluded_symbols=quant_excluded_nodes)
-  assert calib_tensors1 == calib_tensors2
-
 
 def same_graph_structure(symnet_observed, symnet_expected, expected):
   nodes_obs = json.loads(symnet_observed.tojson(remove_amp_cast=False))['nodes']
   nodes_exp = json.loads(symnet_expected.tojson(remove_amp_cast=False))['nodes']
+  nodes_obs = [(node['op'], node['inputs']) for node in nodes_obs]
+  nodes_exp = [(node['op'], node['inputs']) for node in nodes_exp]
   assert (len(nodes_obs) == len(nodes_exp)) == expected
   for node_obs, node_exp in zip(nodes_obs, nodes_exp):
-    if node_obs['op'] != node_exp['op'] or node_obs['inputs'] != node_exp['inputs']:
-      assert expected == False
+    if node_obs != node_exp:
+      assert expected == False, '\n'.join([f'{n1} vs {n2}' for n1, n2 in zip(nodes_obs, nodes_exp)])
       break
 
 
@@ -87,9 +70,6 @@ def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rto
     lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
     same_graph_structure(lp_symnet, expected_sym, True)
 
-  # check amp with quantization
-  check_amp_with_quantization(net, data_example, quantized_nodes)
-
 
 @mx.util.use_np
 def test_amp_fc():
@@ -221,7 +201,8 @@ def test_amp_fuse_with_branch():
         out = self.fc1(x)
         out1 = self.fc2(out)
         out1 = nn.Activation('relu')(out1)
-        out2 = mx.npx.softmax(out)
+        with nn.HybridBlock.OptConstraint.disable_amp():
+          out2 = mx.npx.softmax(out)
         return out1, out2
 
   net = TestNet()
diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
index de680ae352..0aa8396473 100644
--- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -311,7 +311,7 @@ def function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, flatten, relu
   if quantize_mode is not None:
     attrs['fc']['quantized'] = 'true'
     if quantize_mode == 'smart':
-      attrs['fc']['enable_float_output'] = 'true'
+      attrs['fc']['enabled_float_output'] = mx.nd.get_dtype_name(mx.np.float32)
   num_hidden=10
   net = FCWithSumExample(num_hidden, add_op, fc_out_add)
   if flatten:
diff --git a/tests/python/dnnl/test_amp.py b/tests/python/dnnl/test_amp.py
index 73be9bb9a8..df0e2f7159 100644
--- a/tests/python/dnnl/test_amp.py
+++ b/tests/python/dnnl/test_amp.py
@@ -20,9 +20,18 @@ from pathlib import Path
 curr_path = Path(__file__).resolve().parent
 sys.path.insert(0, str(curr_path.parent))
 
+import pytest
 import mxnet as mx
 import amp.common as amp_common_tests
+from mxnet.test_utils import assert_almost_equal
+from mxnet.amp.lists.symbol_bf16 import (BF16_FUNCS, BF16_FP32_FUNCS, WIDEST_TYPE_CASTS,
+                                         CONDITIONAL_FP32_FUNCS)
 
+from op_cfg import get_op_cfg_generator, get_symblock_from_args_scenario, CFG_RTOL_ATOL
+
+
+ALL_BF16_OPS = BF16_FUNCS + BF16_FP32_FUNCS + WIDEST_TYPE_CASTS
+ALL_BF16_OPS += [op_name for op_name, attr_name, attr_vals in CONDITIONAL_FP32_FUNCS]
 
 AMP_DTYPE = 'bfloat16'
 
@@ -54,3 +63,60 @@ def test_bf16_fp32_ops_order_independence():
 @mx.util.use_np
 def test_bf16_test_node_excluding():
     amp_common_tests.test_amp_node_excluding(AMP_DTYPE)
+
+
+def get_param_name(param):
+    if isinstance(param, (mx.nd.NDArray, mx.np.ndarray)):
+        return 'Tensor' + str(param.shape)
+    if isinstance(param, (tuple, list)):
+        return str(type(param)(get_param_name(elem) for elem in param))
+    return str(param)
+
+
+def get_test_name(param):
+    if isinstance(param, str):
+        return f'"{param}" '  # op_name
+    if isinstance(param, dict):
+        elements = []
+        for args_names, args_cfgs in param.items():
+            if isinstance(args_cfgs, tuple):
+                binded_args = args_names.split(',')
+                for arg_name, arg_val in zip(binded_args, args_cfgs):
+                    elements.append(f'"{arg_name}": {get_param_name(arg_val)}')
+            else:
+                arg_name, arg_val = args_names, args_cfgs
+                elements.append(f'"{arg_name}": {get_param_name(arg_val)}')
+        return ' ' + ', '.join(elements)
+    raise TypeError('Op configuration should only consist of its name (str) and arg config (dict)')
+
+
+@pytest.mark.parametrize(argnames=('op_name', 'args_scenario'),
+                         argvalues=get_op_cfg_generator(ALL_BF16_OPS, AMP_DTYPE),
+                         ids=get_test_name)
+def test_bf16_op(op_name, args_scenario):
+    symblock, bf16_symblock_input_data = get_symblock_from_args_scenario(op_name, args_scenario)
+    rtol, atol = args_scenario.get(CFG_RTOL_ATOL, (0.01, None))
+    
+    fp32_symblock_input_data = []
+    for tensor in bf16_symblock_input_data:
+        if mx.nd.get_dtype_name(tensor.dtype) == 'bfloat16':
+            tensor = tensor.astype('float32')
+        fp32_symblock_input_data.append(tensor)
+
+    try:
+        bf16_outs = symblock(*bf16_symblock_input_data)
+        fp32_outs = symblock(*fp32_symblock_input_data)
+        mx.nd.waitall()
+    except mx.MXNetError as e:
+        pytest.fail(str(e))
+
+    if not isinstance(bf16_outs, (list, tuple)):
+        bf16_outs = [bf16_outs]
+    if not isinstance(fp32_outs, (list, tuple)):
+        fp32_outs = [fp32_outs]
+
+    assert any(mx.nd.get_dtype_name(tensor.dtype) == 'bfloat16'
+               for tensor in bf16_symblock_input_data + bf16_outs)
+    assert len(bf16_outs) == len(fp32_outs)
+    for bf16_out, fp32_out in zip(bf16_outs, fp32_outs):
+        assert_almost_equal(bf16_out.astype('float32'), fp32_out.astype('float32'), rtol, atol)

Reply via email to