This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f6d1ed1872 Improve bf16 support (#21002)
f6d1ed1872 is described below
commit f6d1ed1872b745c3a9dd944b7ad5c2ce988b014b
Author: Paweł Głomski <[email protected]>
AuthorDate: Fri Jul 15 10:56:38 2022 +0200
Improve bf16 support (#21002)
* AMP improvements + enable bf16 input for quantize_v2
* Fix sanity
* Improve tests, AMP conversion interface, fix forward hooks
* Fix tests
* Fix imports in tests
* Use different lp16_fp32 op in test
* Add amp.disable_amp() context, fix tests
* Add tests, generalize optimization disabling
* Fix sanity
* Review fixes
* Use is_integral<>::value
* Review fixes
Change flag type to unsigned int
Add a warning for an incorrect flag attribute value
* Extend bf16 support
* Combine enable_float_output and amp_out_dtype parameters
* Add bf16 support to _dnnl_batch_dot
* Fix sanity
* Add bf16 support to all dnnl ops, add tests
* Add license
* Fix conv activation fuse, disable masked_softmax bf16 support
* Fix sanity, add softmax test cases
* Compare bf16 outputs with fp32 reference
Co-authored-by: Bartlomiej Gawrych <[email protected]>
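
A minimal usage sketch of the reworked Python-side interface (module paths and
keyword names here are assumptions drawn from the bullets above, not verbatim
from this diff):

    import mxnet as mx
    from mxnet import amp

    net = mx.gluon.nn.Dense(16)
    net.initialize()
    net.hybridize()

    # Convert the network to bfloat16 through the AMP conversion interface
    # touched by this commit (exact signature assumed).
    net = amp.convert_hybrid_block(net, target_dtype='bfloat16')

    # Per the "Add amp.disable_amp() context" bullet, ops recorded inside
    # this context keep their original precision.
    with amp.disable_amp():
        out = net(mx.nd.ones((2, 8)))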
---
3rdparty/mshadow/mshadow/base.h | 20 +-
3rdparty/mshadow/mshadow/bfloat.h | 2 +
include/mxnet/imperative.h | 1 +
python/mxnet/amp/lists/symbol_bf16.py | 688 +++++++++++++++------
src/common/utils.h | 7 +-
src/imperative/imperative.cc | 2 +-
src/nnvm/low_precision_pass.cc | 6 +-
src/operator/nn/dnnl/dnnl_base-inl.h | 9 +
src/operator/nn/dnnl/dnnl_batch_dot-inl.h | 6 +-
src/operator/nn/dnnl/dnnl_batch_dot.cc | 4 +-
src/operator/nn/dnnl/dnnl_convolution-inl.h | 10 +-
src/operator/nn/dnnl/dnnl_fully_connected-inl.h | 12 +-
src/operator/nn/dnnl/dnnl_reduce.cc | 2 +-
src/operator/numpy/np_elemwise_broadcast_op.h | 6 +-
src/operator/numpy/np_true_divide-inl.h | 6 +-
src/operator/random/sample_op.h | 34 +-
src/operator/subgraph/dnnl/dnnl_batch_dot.cc | 34 +-
src/operator/subgraph/dnnl/dnnl_conv.cc | 32 +-
src/operator/subgraph/dnnl/dnnl_fc.cc | 37 +-
.../subgraph/dnnl/dnnl_fc_sum_fuse_property.h | 8 +-
.../subgraph/dnnl/dnnl_post_amp_property.h | 2 +-
.../subgraph/dnnl/dnnl_post_quantize_property.h | 2 +-
src/operator/subgraph/dnnl/dnnl_transformer-inl.h | 14 +-
src/operator/subgraph/dnnl/dnnl_transformer.cc | 39 +-
src/operator/tensor/elemwise_unary_op.h | 8 +-
tests/python/dnnl/op_cfg.py | 410 ++++++++++++
tests/python/dnnl/subgraphs/subgraph_common.py | 4 +-
tests/python/dnnl/subgraphs/test_amp_subgraph.py | 31 +-
tests/python/dnnl/subgraphs/test_fc_subgraph.py | 2 +-
tests/python/dnnl/test_amp.py | 66 ++
30 files changed, 1152 insertions(+), 352 deletions(-)
diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h
index a160b52201..6181b3e19a 100644
--- a/3rdparty/mshadow/mshadow/base.h
+++ b/3rdparty/mshadow/mshadow/base.h
@@ -409,7 +409,7 @@ struct DataType<half::half_t> {
#endif
#endif
};
-template<>
+template <>
struct DataType<bfloat::bf16_t> {
static const int kFlag = kBfloat16;
static const int kLanes = 1;
@@ -769,6 +769,10 @@ namespace isnan_typed {
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
}
+ template <>
+ MSHADOW_XINLINE bool IsNan(volatile mshadow::bfloat::bf16_t val) {
+ return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) > MSHADOW_BF16_EXPONENT_BITS;
+ }
} // namespace isnan_typed
/*! \brief
@@ -795,6 +799,10 @@ namespace isinf_typed {
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
    return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
}
+ template <>
+ MSHADOW_XINLINE bool IsInf(volatile mshadow::bfloat::bf16_t val) {
+    return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) == MSHADOW_BF16_EXPONENT_BITS;
+ }
} // namespace isinf_typed
/*! \brief namespace for potential reducer operations */
@@ -881,6 +889,11 @@ MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
    return half::half_t::Binary(MSHADOW_HALF_SIGN_BIT | MSHADOW_HALF_EXPONENT_BITS);
}
+/*! \brief negative infinity value of bfloat16 */
+template <>
+MSHADOW_XINLINE bfloat::bf16_t NegInfValue<bfloat::bf16_t>(void) {
+  return bfloat::bf16_t::Binary(MSHADOW_BF16_SIGN_BIT | MSHADOW_BF16_EXPONENT_BITS);
+}
/*!
* \brief maximum value of certain types
@@ -962,6 +975,11 @@ template<>
MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
}
+/*! \brief positive infinity value of bfloat16 */
+template <>
+MSHADOW_XINLINE bfloat::bf16_t PosInfValue<bfloat::bf16_t>(void) {
+ return bfloat::bf16_t::Binary(MSHADOW_BF16_EXPONENT_BITS);
+}
} // namespace limits
diff --git a/3rdparty/mshadow/mshadow/bfloat.h b/3rdparty/mshadow/mshadow/bfloat.h
index 94bbb4acf5..8a9f5b0aaf 100644
--- a/3rdparty/mshadow/mshadow/bfloat.h
+++ b/3rdparty/mshadow/mshadow/bfloat.h
@@ -180,6 +180,8 @@ MSHADOW_BF16_OPERATOR(bool, <=)
#define MSHADOW_BF16_MIN mshadow::bfloat::bf16_t::Binary(0xFF7F);
#define MSHADOW_BF16_MAX mshadow::bfloat::bf16_t::Binary(0x7F7F);
+#define MSHADOW_BF16_SIGN_BIT 0x8000
+#define MSHADOW_BF16_EXPONENT_BITS 0x7f80
} // namespace bfloat
} // namespace mshadow
#endif // MSHADOW_BFLOAT_H_
\ No newline at end of file
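
The two masks added above encode bfloat16's bit layout: 1 sign bit (0x8000),
8 exponent bits (0x7f80), 7 mantissa bits. A value is Inf when, ignoring the
sign, exactly the exponent bits are set, and NaN when the exponent bits are
set plus a non-zero mantissa, which is what the new IsNan/IsInf
specializations test. A pure-Python sketch mirroring the same masks (not part
of the diff):

    import struct

    def bf16_bits(x):
        # bf16 is the upper 16 bits of the float32 representation
        return struct.unpack('<I', struct.pack('<f', x))[0] >> 16

    def is_nan(bits):  # exponent all ones and mantissa non-zero
        return (bits & ~0x8000) > 0x7f80

    def is_inf(bits):  # exponent all ones and mantissa zero
        return (bits & ~0x8000) == 0x7f80

    assert is_nan(bf16_bits(float('nan')))
    assert is_inf(bf16_bits(float('inf')))
    assert is_inf(bf16_bits(float('-inf')))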
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 42876f7bf4..75cf3e2a80 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -42,6 +42,7 @@ enum class OptConstraint : unsigned int {
DisableAMP = 1 << 0
// DisableQuantization = 1 << 1
};
+using OptConstraint_int_t = std::underlying_type_t<OptConstraint>;
/*! \brief there are three numpy shape flags based on priority.
* GlobalOn
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index 566990c411..89ddea2820 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -40,19 +40,69 @@ if Features.instance.is_enabled('ONEDNN'):
# like image transformations or optimizers) or they
# are dtype neutral (can work in both bf16 and fp32)
BF16_FP32_FUNCS = [
- 'abs',
+ '_contrib_AdaptiveAvgPooling2D',
+ '_contrib_BatchNormWithReLU',
+ 'Activation',
'BatchNorm',
- 'clip',
- 'Flatten',
+ 'LayerNorm',
'LRN',
+ 'softmax',
+ 'log_softmax',
+ #'masked_softmax', TODO: fix segfault appearing for a 4D input tensor
'Pooling',
- 'relu',
- '_shuffle',
- 'sqrt',
- 'square',
- 'tanh',
+ '_npi_mean',
+ '_npi_sum',
+ '_npi_square',
+ '_npi_sqrt',
+ '_npi_exp',
+ '_npi_tanh',
+ '_npi_transpose',
+ '_npx_reshape',
+ '_npi_where',
+ #'_contrib_quantize_asym', # used in rnn, which is hard to convert to bf16
'_contrib_quantize_v2',
+ #'_contrib_quantize', # not used anymore
+ 'sum',
+ 'mean',
+ '_copy',
+ 'Reshape',
+ 'Flatten',
+ 'transpose',
+ 'expand_dims',
+ 'slice',
+ 'stack',
+ 'space_to_depth',
+ '_split_v2',
+
+ # no oneDNN support:
+ 'Cast',
+ 'where',
+ 'take',
]
+# 'RNN', # GetEnv("MXNET_USE_ONEDNN_RNN", 1)
+
+# Functions with multiple inputs, that need the same
+# type of all their inputs
+WIDEST_TYPE_CASTS = [
+ 'Concat',
+ 'dot',
+ 'batch_dot',
+ 'broadcast_add',
+ 'broadcast_sub',
+ 'broadcast_mul',
+ 'broadcast_div',
+ 'elemwise_add',
+ 'add_n',
+ '_npi_dot',
+ '_npi_add',
+ '_npi_multiply',
+ '_npi_subtract',
+ '_npi_true_divide',
+]
+if Features.instance.is_enabled('ONEDNN'):
+ WIDEST_TYPE_CASTS.extend([
+ '_sg_onednn_batch_dot',
+ ])
# Functions that when running with Bfloat16, the params that still need float32.
BF16_USE_FP32_PARAMS = {
@@ -63,106 +113,439 @@ BF16_USE_FP32_PARAMS = {
# Functions that have to be cast to FP32 due to possible
# overflows
FP32_FUNCS = [
- 'RNN',
+ 'amp_cast',
+ 'amp_multicast',
+ 'masked_softmax',
'BilinearSampler',
'BlockGrad',
- 'Cast',
- 'cast_storage',
+ 'CTCLoss',
+ 'Correlation',
'Crop',
+ 'Custom',
'Dropout',
'Embedding',
'GridGenerator',
+ 'GroupNorm',
+ 'IdentityAttachKLSparseReg',
+ 'InstanceNorm',
+ 'L2Normalization',
+ 'LinearRegressionOutput',
+ 'LogisticRegressionOutput',
+ 'MAERegressionOutput',
+ 'MakeLoss',
'Pad',
+ 'RNN',
'ROIPooling',
- 'Reshape',
+ 'SVMOutput',
'SequenceLast',
'SequenceMask',
'SequenceReverse',
'SliceChannel',
+ 'SoftmaxActivation',
+ 'SoftmaxOutput',
'SpatialTransformer',
'SwapAxis',
'UpSampling',
'_CachedOp',
+ '_CachedOpThreadSafe',
'_CrossDeviceCopy',
'_CustomFunction',
+ '_NDArray',
+ '_Native',
'_NoGradient',
+ '_adabelief_update',
'_adamw_update',
'_arange',
'_cond',
- '_contrib_interleaved_matmul_selfatt_qk',
- '_contrib_interleaved_matmul_selfatt_valatt',
- '_contrib_AdaptiveAvgPooling2D',
'_contrib_BilinearResize2D',
+ '_contrib_DeformablePSROIPooling',
+ '_contrib_MultiBoxDetection',
+ '_contrib_MultiBoxPrior',
+ '_contrib_MultiBoxTarget',
+ '_contrib_MultiProposal',
+ '_contrib_PSROIPooling',
+ '_contrib_Proposal',
+ '_contrib_ROIAlign',
+ '_contrib_RROIAlign',
+ '_contrib_SyncBatchNorm',
+ '_contrib_allclose',
+ '_contrib_arange_like',
'_contrib_bipartite_matching',
+ '_contrib_boolean_mask',
+ '_contrib_box_decode',
+ '_contrib_box_encode',
+ '_contrib_box_iou',
+ '_contrib_box_nms',
+ '_contrib_calibrate_entropy',
+ '_contrib_count_sketch',
'_contrib_dequantize',
+ '_contrib_dgl_adjacency',
+ '_contrib_dgl_csr_neighbor_non_uniform_sample',
+ '_contrib_dgl_csr_neighbor_uniform_sample',
+ '_contrib_dgl_graph_compact',
+ '_contrib_dgl_subgraph',
'_contrib_div_sqrt_dim',
- '_contrib_boolean_mask',
+ '_contrib_dynamic_reshape',
+ '_contrib_edge_id',
+ '_contrib_fft',
'_contrib_getnnz',
'_contrib_gradientmultiplier',
'_contrib_group_adagrad_update',
+ '_contrib_hawkesll',
'_contrib_index_array',
'_contrib_index_copy',
+ '_contrib_interleaved_matmul_encdec_qk',
+ '_contrib_interleaved_matmul_encdec_valatt',
+ '_contrib_interleaved_matmul_selfatt_qk',
+ '_contrib_interleaved_matmul_selfatt_valatt',
+ '_contrib_intgemm_fully_connected',
+ '_contrib_intgemm_maxabsolute',
+ '_contrib_intgemm_prepare_data',
+ '_contrib_intgemm_prepare_weight',
+ '_contrib_intgemm_take_weight',
'_contrib_quadratic',
'_contrib_quantize',
'_contrib_quantize_asym',
+ '_contrib_quantized_act',
+ '_contrib_quantized_batch_norm',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
+ '_contrib_quantized_elemwise_add',
+ '_contrib_quantized_elemwise_mul',
+ '_contrib_quantized_embedding',
'_contrib_quantized_flatten',
'_contrib_quantized_fully_connected',
'_contrib_quantized_pooling',
- '_contrib_quantized_elemwise_add',
- '_contrib_quantized_act',
+ '_contrib_quantized_reshape',
'_contrib_quantized_rnn',
- '_image_crop',
- '_linspace',
+ '_contrib_quantized_transpose',
'_contrib_requantize',
- '_copy',
+ '_contrib_round_ste',
+ '_contrib_sign_ste',
+ '_contrib_sldwin_atten_context',
+ '_contrib_sldwin_atten_mask_like',
+ '_contrib_sldwin_atten_score',
'_copyto',
'_cvcopyMakeBorder',
'_cvimdecode',
'_cvimread',
'_cvimresize',
'_div_scalar',
+ '_equal',
'_equal_scalar',
'_eye',
'_foreach',
- '_while_loop',
'_full',
'_grad_add',
- '_greater_scalar',
+ '_greater',
+ '_greater_equal',
'_greater_equal_scalar',
+ '_greater_scalar',
'_histogram',
+ '_hypot',
+ '_hypot_scalar',
'_identity_with_attr_like_rhs',
'_image_adjust_lighting',
+ '_image_crop',
'_image_flip_left_right',
'_image_flip_top_bottom',
'_image_normalize',
'_image_random_brightness',
'_image_random_color_jitter',
'_image_random_contrast',
+ '_image_random_crop',
'_image_random_flip_left_right',
'_image_random_flip_top_bottom',
'_image_random_hue',
'_image_random_lighting',
+ '_image_random_resized_crop',
'_image_random_saturation',
'_image_resize',
'_image_to_tensor',
'_imdecode',
- '_lesser_scalar',
+ '_lesser',
+ '_lesser_equal',
'_lesser_equal_scalar',
+ '_lesser_scalar',
+ '_linalg_det',
+ '_linalg_extractdiag',
+ '_linalg_extracttrian',
+ '_linalg_gelqf',
+ '_linalg_gemm',
+ '_linalg_gemm2',
+ '_linalg_inverse',
+ '_linalg_makediag',
+ '_linalg_maketrian',
+ '_linalg_potrf',
+ '_linalg_potri',
+ '_linalg_slogdet',
+ '_linalg_sumlogdiag',
+ '_linalg_syevd',
+ '_linalg_syrk',
+ '_linalg_trmm',
+ '_linalg_trsm',
+ '_linspace',
+ '_logical_and',
'_logical_and_scalar',
+ '_logical_or',
'_logical_or_scalar',
+ '_logical_xor',
'_logical_xor_scalar',
+ '_maximum',
'_maximum_scalar',
+ '_minimum',
'_minimum_scalar',
'_minus_scalar',
+ '_mod',
'_mod_scalar',
+ '_mp_adabelief_update',
'_mp_adamw_update',
'_mul_scalar',
+ '_multi_adabelief_update',
+ '_multi_adamw_update',
+ '_multi_lamb_update',
+ '_multi_lans_update',
+ '_multi_mp_adabelief_update',
+ '_multi_mp_adamw_update',
+ '_multi_mp_lamb_update',
+ '_multi_mp_lans_update',
+ '_not_equal',
'_not_equal_scalar',
+ '_np_reshape',
+ '_npi_absolute',
+ '_npi_add_scalar',
+ '_npi_advanced_indexing',
+ '_npi_advanced_indexing_multiple',
+ '_npi_all',
+ '_npi_any',
+ '_npi_arange',
+ '_npi_arccos',
+ '_npi_arccosh',
+ '_npi_arcsin',
+ '_npi_arcsinh',
+ '_npi_arctan',
+ '_npi_arctan2',
+ '_npi_arctan2_scalar',
+ '_npi_arctanh',
+ '_npi_argmax',
+ '_npi_argmin',
+ '_npi_around',
+ '_npi_atleast_1d',
+ '_npi_atleast_2d',
+ '_npi_atleast_3d',
+ '_npi_average',
+ '_npi_bernoulli',
+ '_npi_bincount',
+ '_npi_bitwise_and',
+ '_npi_bitwise_and_scalar',
+ '_npi_bitwise_left_shift',
+ '_npi_bitwise_left_shift_scalar',
+ '_npi_bitwise_not',
+ '_npi_bitwise_or',
+ '_npi_bitwise_or_scalar',
+ '_npi_bitwise_right_shift',
+ '_npi_bitwise_right_shift_scalar',
+ '_npi_bitwise_xor',
+ '_npi_bitwise_xor_scalar',
+ '_npi_blackman',
+ '_npi_boolean_mask_assign_scalar',
+ '_npi_boolean_mask_assign_tensor',
+ '_npi_broadcast_to',
+ '_npi_cbrt',
+ '_npi_ceil',
+ '_npi_choice',
+ '_npi_cholesky',
+ '_npi_column_stack',
+ '_npi_copy',
+ '_npi_copysign',
+ '_npi_copysign_scalar',
+ '_npi_cos',
+ '_npi_cosh',
+ '_npi_cross',
+ '_npi_cumsum',
+ '_npi_degrees',
+ '_npi_delete',
+ '_npi_diag',
+ '_npi_diag_indices_from',
+ '_npi_diagflat',
+ '_npi_diagonal',
+ '_npi_diff',
+ '_npi_dsplit',
+ '_npi_dstack',
+ '_npi_ediff1d',
+ '_npi_eig',
+ '_npi_eigh',
+ '_npi_eigvals',
+ '_npi_eigvalsh',
+ '_npi_einsum',
+ '_npi_equal',
+ '_npi_equal_scalar',
+ '_npi_expm1',
+ '_npi_exponential',
+ '_npi_eye',
+ '_npi_fill_diagonal',
+ '_npi_fix',
+ '_npi_flip',
+ '_npi_floor',
+ '_npi_floor_divide',
+ '_npi_floor_divide_scalar',
+ '_npi_fmax',
+ '_npi_fmax_scalar',
+ '_npi_fmin',
+ '_npi_fmin_scalar',
+ '_npi_fmod',
+ '_npi_fmod_scalar',
+ '_npi_full',
+ '_npi_full_like',
+ '_npi_gamma',
+ '_npi_gcd',
+ '_npi_gcd_scalar',
+ '_npi_greater',
+ '_npi_greater_equal',
+ '_npi_greater_equal_scalar',
+ '_npi_greater_scalar',
+ '_npi_gumbel',
+ '_npi_hamming',
+ '_npi_hanning',
+ '_npi_hsplit',
+ '_npi_hstack',
+ '_npi_hypot',
+ '_npi_identity',
+ '_npi_indices',
+ '_npi_insert_scalar',
+ '_npi_insert_slice',
+ '_npi_insert_tensor',
+ '_npi_interp',
+ '_npi_isfinite',
+ '_npi_isinf',
+ '_npi_isnan',
+ '_npi_isneginf',
+ '_npi_isposinf',
+ '_npi_kron',
+ '_npi_laplace',
+ '_npi_lcm',
+ '_npi_lcm_scalar',
+ '_npi_ldexp',
+ '_npi_ldexp_scalar',
+ '_npi_less',
+ '_npi_less_equal',
+ '_npi_less_equal_scalar',
+ '_npi_less_scalar',
+ '_npi_linspace',
+ '_npi_log',
+ '_npi_log10',
+ '_npi_log1p',
+ '_npi_log2',
+ '_npi_logaddexp',
+ '_npi_logaddexp_scalar',
+ '_npi_logical_and',
+ '_npi_logical_and_scalar',
+ '_npi_logical_not',
+ '_npi_logical_or',
+ '_npi_logical_or_scalar',
+ '_npi_logical_xor',
+ '_npi_logical_xor_scalar',
+ '_npi_logistic',
+ '_npi_logspace',
+ '_npi_lstsq',
+ '_npi_matmul',
+ '_npi_matrix_rank',
+ '_npi_matrix_rank_none_tol',
+ '_npi_max',
+ '_npi_min',
+ '_npi_mod',
+ '_npi_mod_scalar',
+ '_npi_moveaxis',
+ '_npi_multinomial',
+ '_npi_multiply_scalar',
+ '_npi_nan_to_num',
+ '_npi_negative',
+ '_npi_norm',
+ '_npi_normal',
+ '_npi_normal_n',
+ '_npi_not_equal',
+ '_npi_not_equal_scalar',
+ '_npi_ones',
+ '_npi_pad',
+ '_npi_pareto',
+ '_npi_percentile',
+ '_npi_pinv',
+ '_npi_pinv_scalar_rcond',
+ '_npi_polyval',
+ '_npi_power',
+ '_npi_power_scalar',
+ '_npi_powerd',
+ '_npi_prod',
+ '_npi_qr',
+ '_npi_radians',
+ '_npi_rarctan2_scalar',
+ '_npi_rayleigh',
+ '_npi_rbitwise_left_shift_scalar',
+ '_npi_rbitwise_right_shift_scalar',
+ '_npi_rcopysign_scalar',
+ '_npi_reciprocal',
+ '_npi_repeats',
+ '_npi_rfloor_divide_scalar',
+ '_npi_rfmod_scalar',
+ '_npi_rint',
+ '_npi_rldexp_scalar',
+ '_npi_rmod_scalar',
+ '_npi_roll',
+ '_npi_rollaxis',
+ '_npi_rot90',
+ '_npi_rpower_scalar',
+ '_npi_rsubtract_scalar',
+ '_npi_rtrue_divide_scalar',
+ '_npi_share_memory',
+ '_npi_sign',
+ '_npi_sin',
+ '_npi_sinh',
+ '_npi_solve',
+ '_npi_squeeze',
+ '_npi_std',
+ '_npi_subtract_scalar',
+ '_npi_svd',
+ '_npi_tan',
+ '_npi_tensordot',
+ '_npi_tensordot_int_axes',
+ '_npi_tensorinv',
+ '_npi_tensorsolve',
+ '_npi_trace',
+ '_npi_tri',
+ '_npi_tril',
+ '_npi_tril_indices',
+ '_npi_triu',
+ '_npi_true_divide_scalar',
+ '_npi_trunc',
+ '_npi_uniform',
+ '_npi_uniform_n',
+ '_npi_unique',
+ '_npi_var',
+ '_npi_vstack',
+ '_npi_weibull',
+ '_npi_where_lscalar',
+ '_npi_where_rscalar',
+ '_npi_where_scalar2',
+ '_npi_zeros',
+ '_npx_cond',
+ '_npx_constraint_check',
+ '_npx_deformable_convolution',
+ '_npx_foreach',
+ '_npx_index_add',
+ '_npx_index_update',
+ '_npx_modulated_deformable_convolution',
+ '_npx_nonzero',
+ '_npx_quantized_reshape',
+ '_npx_quantized_transpose',
+ '_npx_relu',
+ '_npx_sigmoid',
+ '_npx_while_loop',
'_onehot_encode',
'_ones',
'_plus_scalar',
+ '_power',
+ '_power_scalar',
+ '_random_binomial',
'_random_exponential',
'_random_exponential_like',
'_random_gamma',
@@ -173,15 +556,27 @@ FP32_FUNCS = [
'_random_negative_binomial_like',
'_random_normal',
'_random_normal_like',
+ '_random_pdf_dirichlet',
+ '_random_pdf_exponential',
+ '_random_pdf_gamma',
+ '_random_pdf_generalized_negative_binomial',
+ '_random_pdf_negative_binomial',
+ '_random_pdf_normal',
+ '_random_pdf_poisson',
+ '_random_pdf_uniform',
'_random_poisson',
'_random_poisson_like',
'_random_randint',
'_random_uniform',
'_random_uniform_like',
'_ravel_multi_index',
+ '_rdiv_scalar',
'_rminus_scalar',
'_rmod_scalar',
'_rnn_param_concat',
+ '_rpower_scalar',
+ '_sample_binomial',
+ '_sample_categorical',
'_sample_exponential',
'_sample_gamma',
'_sample_generalized_negative_binomial',
@@ -193,67 +588,128 @@ FP32_FUNCS = [
'_sample_unique_zipfian',
'_scatter_set_nd',
'_set_value',
+ '_shuffle',
'_slice_assign',
'_slice_assign_scalar',
'_sparse_adagrad_update',
'_sparse_retain',
- '_split_v2',
+ '_square_sum',
'_unravel_index',
+ '_while_loop',
'_zeros',
'_zeros_without_dtype',
+ 'abs',
'adam_update',
'all_finite',
- # 'amp_cast',
- # 'amp_multicast',
+ 'arccos',
'arccosh',
+ 'arcsin',
'arcsinh',
'arctan',
+ 'arctanh',
'argmax',
'argmax_channel',
'argmin',
+ 'argsort',
'batch_take',
'broadcast_axis',
+ 'broadcast_equal',
+ 'broadcast_greater',
+ 'broadcast_greater_equal',
+ 'broadcast_hypot',
+ 'broadcast_lesser',
+ 'broadcast_lesser_equal',
'broadcast_like',
+ 'broadcast_logical_and',
+ 'broadcast_logical_or',
+ 'broadcast_logical_xor',
+ 'broadcast_maximum',
+ 'broadcast_minimum',
+ 'broadcast_mod',
+ 'broadcast_not_equal',
+ 'broadcast_power',
'broadcast_to',
+ 'cast_storage',
'cbrt',
'ceil',
+ 'clip',
+ 'col2im',
'cos',
+ 'cosh',
'degrees',
'depth_to_space',
'diag',
+ 'digamma',
+ 'elemwise_div',
+ 'elemwise_mul',
+ 'elemwise_sub',
'erf',
- 'expand_dims',
+ 'erfinv',
+ 'exp',
+ 'expm1',
'fill_element_0index',
'fix',
'floor',
'ftml_update',
'ftrl_update',
+ 'gamma',
+ 'gammaln',
'gather_nd',
'hard_sigmoid',
- 'logical_not',
+ 'im2col',
+ 'khatri_rao',
+ 'lamb_update_phase1',
+ 'lamb_update_phase2',
+ 'log',
+ 'log10',
+ 'log1p',
+ 'log2',
'log_sigmoid',
+ 'logical_not',
+ 'make_loss',
+ 'masked_log_softmax',
'max',
'min',
'mish',
+ 'moments',
+ 'mp_lamb_update_phase1',
+ 'mp_lamb_update_phase2',
+ 'mp_nag_mom_update',
'mp_sgd_mom_update',
'mp_sgd_update',
'multi_all_finite',
+ 'multi_lars',
'multi_mp_sgd_mom_update',
'multi_mp_sgd_update',
'multi_sgd_mom_update',
'multi_sgd_update',
+ 'multi_sum_sq',
+ 'nag_mom_update',
+ 'nanprod',
+ 'nansum',
'negative',
+ 'norm',
'one_hot',
'ones_like',
'pick',
+ 'preloaded_multi_mp_sgd_mom_update',
+ 'preloaded_multi_mp_sgd_update',
+ 'preloaded_multi_sgd_mom_update',
+ 'preloaded_multi_sgd_update',
+ 'prod',
'radians',
+ 'rcbrt',
+ 'reciprocal',
+ 'relu',
'repeat',
+ 'reset_arrays',
'reshape_like',
'reverse',
'rint',
'rmsprop_update',
'rmspropalex_update',
'round',
+ 'rsqrt',
'scatter_nd',
'sgd_mom_update',
'sgd_update',
@@ -263,184 +719,32 @@ FP32_FUNCS = [
'signsgd_update',
'signum_update',
'sin',
+ 'sinh',
'size_array',
- 'slice',
'slice_axis',
'slice_like',
+ 'smooth_l1',
+ 'softmax_cross_entropy',
+ 'softmin',
'softsign',
'sort',
- 'space_to_depth',
+ 'sqrt',
+ 'square',
'squeeze',
- 'take',
+ 'tan',
+ 'tanh',
'tile',
- 'transpose',
+ 'topk',
'trunc',
'zeros_like',
- 'broadcast_mul',
- 'IdentityAttachKLSparseReg',
- 'arccos',
- 'arcsin',
- 'cosh',
- 'erfinv',
- 'sinh',
- 'tan',
- 'arctanh',
-
- # Exponents
- 'exp',
- 'expm1',
- 'log',
- 'log10',
- 'log2',
- 'log1p',
-
- # Powers
- 'broadcast_power',
- 'reciprocal',
- '_rdiv_scalar',
- 'rsqrt',
- 'rcbrt',
- '_power',
- '_power_scalar',
- '_rpower_scalar',
- '_hypot',
- '_hypot_scalar',
- 'broadcast_hypot',
- '_square_sum',
- '_contrib_hawkesll',
-
- # Reductions
- 'sum',
- 'nansum',
- 'prod',
- 'nanprod',
- 'mean',
- 'norm',
- 'softmin',
- 'khatri_rao',
- 'moments',
-
- # Misc
- 'gamma',
- 'gammaln',
- '_linalg_gelqf',
- '_linalg_gemm',
- '_linalg_gemm2',
- '_linalg_potrf',
- '_linalg_potri',
- '_linalg_sumlogdiag',
- '_linalg_syevd',
- '_linalg_syrk',
- '_linalg_trmm',
- '_linalg_trsm',
- '_linalg_makediag',
- '_linalg_extractdiag',
- '_linalg_maketrian',
- '_linalg_extracttrian',
- '_linalg_inverse',
- '_linalg_det',
- '_linalg_slogdet',
- '_NDArray',
- '_Native',
- '_contrib_count_sketch',
- '_contrib_SyncBatchNorm',
- '_contrib_fft',
- 'argsort',
- 'topk',
-
- # Neural network
- 'softmax',
- 'log_softmax',
- 'masked_softmax',
- 'masked_log_softmax',
- 'InstanceNorm',
- 'LayerNorm',
- 'GroupNorm',
- 'L2Normalization',
- 'SoftmaxActivation',
- 'softmax_cross_entropy',
- 'smooth_l1',
- 'MakeLoss',
- 'make_loss',
- 'Custom',
- 'CTCLoss',
- '_npx_deformable_convolution',
- '_contrib_DeformablePSROIPooling',
]
# Functions that have to be cast to FP32 only for
# some values of their parameters
CONDITIONAL_FP32_FUNCS = [
- ('Activation', 'act_type', ['softrelu']),
- ('LeakyReLU', 'act_type', ['elu', 'selu']),
-]
-
-# Functions with multiple inputs, that need the same
-# type of all their inputs
-WIDEST_TYPE_CASTS = [
- '_npi_add',
- 'Concat',
- '_equal',
- '_greater',
- '_greater_equal',
- '_lesser',
- '_lesser_equal',
- '_logical_and',
- '_logical_or',
- '_logical_xor',
- '_maximum',
- '_minimum',
- '_mod',
- '_not_equal',
- 'Correlation',
- 'add_n',
- 'batch_dot',
- 'broadcast_add',
- 'broadcast_div',
- 'broadcast_equal',
- 'broadcast_greater',
- 'broadcast_greater_equal',
- 'broadcast_lesser',
- 'broadcast_lesser_equal',
- 'broadcast_logical_and',
- 'broadcast_logical_or',
- 'broadcast_logical_xor',
- 'broadcast_maximum',
- 'broadcast_minimum',
- 'broadcast_mod',
- 'broadcast_not_equal',
- 'broadcast_sub',
- 'dot',
- 'elemwise_add',
- 'elemwise_div',
- 'elemwise_mul',
- 'elemwise_sub',
- 'stack',
- '_contrib_MultiBoxDetection',
- '_contrib_MultiBoxPrior',
- '_contrib_MultiBoxTarget',
- '_contrib_MultiProposal',
- '_contrib_PSROIPooling',
- '_contrib_Proposal',
- '_contrib_ROIAlign',
- '_contrib_box_iou',
- '_contrib_box_nms',
- '_contrib_dgl_adjacency',
- '_contrib_dgl_csr_neighbor_non_uniform_sample',
- '_contrib_dgl_csr_neighbor_uniform_sample',
- '_contrib_dgl_graph_compact',
- '_contrib_dgl_subgraph',
- '_contrib_edge_id',
- 'where',
- '_random_pdf_gamma',
- '_random_pdf_exponential',
- '_random_pdf_uniform',
- '_random_pdf_negative_binomial',
- '_random_pdf_generalized_negative_binomial',
- '_random_pdf_dirichlet',
- '_random_pdf_normal',
- '_random_pdf_poisson',
+ ('LeakyReLU', 'act_type', ['selu']),
]
LOSS_OUTPUT_FUNCTIONS = [
+ 'SoftmaxOutput'
]
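
These lists drive the low-precision graph pass: BF16_FP32_FUNCS may run in
either dtype, FP32_FUNCS stay pinned to float32, and WIDEST_TYPE_CASTS ops get
all of their inputs cast to the widest dtype among them. An illustrative
sketch of the widest-type rule (the real logic lives in
src/nnvm/low_precision_pass.cc; the dtype ordering below is an assumption for
the sketch):

    WIDTH = {'bfloat16': 0, 'float16': 0, 'float32': 1, 'float64': 2}

    def widest(dtypes):
        return max(dtypes, key=lambda dt: WIDTH[dt])

    # e.g. batch_dot(bf16, fp32) has both inputs cast to fp32
    assert widest(['bfloat16', 'float32']) == 'float32'
    assert widest(['bfloat16', 'bfloat16']) == 'bfloat16'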
diff --git a/src/common/utils.h b/src/common/utils.h
index fe8413f18e..2c1f9d5758 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -929,11 +929,8 @@ inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) {
}
inline bool is_float(const int dtype) {
-  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
-}
-
-inline bool is_bfloat(const int dtype) {
- return dtype == mshadow::kBfloat16;
+  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16 ||
+         dtype == mshadow::kBfloat16;
}
inline bool is_int(const int dtype) {
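
With this consolidation, bfloat16 counts as a float type everywhere is_float()
is used, which is why the call sites in np_elemwise_broadcast_op.h and
np_true_divide-inl.h further down can drop their extra is_bfloat() checks. A
Python transcription of the new predicate:

    # mirrors the consolidated common::is_float() above
    FLOAT_DTYPES = {'float16', 'bfloat16', 'float32', 'float64'}

    def is_float(dtype):
        return dtype in FLOAT_DTYPES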
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fb123c18c9..489737b196 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -371,7 +371,7 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs&& attrs,
if (get_opt_constraints() != OptConstraint::None) {
node->attrs.dict[OPT_CONSTRAINT_ATTR] =
-        std::to_string(static_cast<std::underlying_type_t<OptConstraint>>(get_opt_constraints()));
+        std::to_string(static_cast<OptConstraint_int_t>(get_opt_constraints()));
}
for (uint32_t i = 0; i < outputs.size(); ++i) {
diff --git a/src/nnvm/low_precision_pass.cc b/src/nnvm/low_precision_pass.cc
index 2c2cce5b99..9b26d366f2 100644
--- a/src/nnvm/low_precision_pass.cc
+++ b/src/nnvm/low_precision_pass.cc
@@ -407,10 +407,10 @@ Graph ReducePrecision(Graph&& src) {
}
return;
}
-  auto opt_constraints = common::flag_attr_accumulate<std::underlying_type_t<OptConstraint>>(
-      old_node->attrs, OPT_CONSTRAINT_ATTR);
+  auto opt_constraints =
+      common::flag_attr_accumulate<OptConstraint_int_t>(old_node->attrs, OPT_CONSTRAINT_ATTR);
if (fp32_ops.count(old_node->op()->name) > 0 ||
- (opt_constraints & static_cast<int>(OptConstraint::DisableAMP))) {
+      (opt_constraints & static_cast<OptConstraint_int_t>(OptConstraint::DisableAMP))) {
KeepOriginalNode(old_node, node_map, &entry_map);
} else if (target_dtype_ops.count(old_node->op()->name) > 0) {
      if (!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, &entry_map)) {
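
A sketch of the constraint-flag mechanism used here (attribute key and
semantics assumed from the surrounding code): the OptConstraint bits
accumulated while recording are serialized into a node attribute as a
stringified unsigned int, and a node whose DisableAMP bit is set is kept in
its original precision:

    DISABLE_AMP = 1 << 0  # mirrors OptConstraint::DisableAMP

    def keep_original_dtype(node_attrs, attr_key='__opt_constraint__'):
        # attr_key is hypothetical; the C++ side uses OPT_CONSTRAINT_ATTR
        flags = int(node_attrs.get(attr_key, '0'))
        return bool(flags & DISABLE_AMP)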
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 0b9645b8a3..7183fbded6 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -58,6 +58,15 @@
LOG(FATAL) << "Unknown type enum " << type; \
}
+// TODO(PawelGlomski-Intel): add bfloat16 for quantized ops
+#define DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER()                                            \
+  DMLC_DECLARE_FIELD(enabled_float_output)                                                       \
+      .set_default(dmlc::optional<int>())                                                        \
+      .add_enum("float32", mshadow::kFloat32)                                                    \
+      .describe(                                                                                 \
+      "Imposed float output. Used to change the output dtype when the operator operates on "     \
+      "low precision data - (u)int8 or lp16.")
+
namespace mxnet {
// ===== CpuEngine =======================================
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index 19233828dc..b7f194cef7 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -44,7 +44,7 @@ struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
- bool enable_float_output;
+ dmlc::optional<int> enabled_float_output;
DMLC_DECLARE_PARAMETER(DNNLDotParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
@@ -65,9 +65,7 @@ struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
"The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized convolution op to calculate primitive scale");
- DMLC_DECLARE_FIELD(enable_float_output)
- .set_default(false)
- .describe("Whether to enable float32 output.");
+ DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
}
bool operator==(const DNNLDotParam& other) const {
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot.cc b/src/operator/nn/dnnl/dnnl_batch_dot.cc
index a40d55621c..f61806ce43 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -79,7 +79,7 @@ dnnl::primitive_attr GetQuantizationAttributes(const DNNLDotParam& param,
param.max_calib_range.value()) /
lhs_scale_ / rhs_scale_;
attr.set_output_scales(0, {out_scale_});
- } else if (param.enable_float_output) {
+ } else if (param.enabled_float_output.has_value()) {
out_scale_ = 1.0 / lhs_scale_ / rhs_scale_;
attr.set_output_scales(0, {out_scale_});
}
@@ -159,7 +159,7 @@ void DNNLBatchDotFwd::Execute(const OpContext& ctx,
CommitOutput(outputs[0], out_mem);
DNNLStream::Get()->Submit();
- if (param.quantized && !param.enable_float_output) {
+ if (param.quantized && !param.enabled_float_output.has_value()) {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
float min_output;
float max_output;
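
The scale handling above follows the usual oneDNN dequantization arithmetic; a
Python transcription (GetQuantizeScale is assumed to map a calibrated
[min, max] range to a scale such as 127 / max(|min|, |max|) for int8):

    def batch_dot_output_scale(lhs_scale, rhs_scale, calib_scale=None):
        if calib_scale is not None:
            # requantize int32 accumulators to the calibrated int8 range
            return calib_scale / lhs_scale / rhs_scale
        # enabled_float_output: dequantize straight to float
        return 1.0 / lhs_scale / rhs_scale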
diff --git a/src/operator/nn/dnnl/dnnl_convolution-inl.h b/src/operator/nn/dnnl/dnnl_convolution-inl.h
index 738be8214f..d5a45468e2 100644
--- a/src/operator/nn/dnnl/dnnl_convolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -42,12 +42,11 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
bool with_sum;
bool with_postsum_act;
bool quantized;
- bool enable_float_output;
bool dedup_sum;
  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
-  dmlc::optional<int> amp_out_dtype;  // mshadow dtype of a fused amp_cast node
+ dmlc::optional<int> enabled_float_output;
DMLC_DECLARE_PARAMETER(DNNLConvParam) {
    DMLC_DECLARE_FIELD(with_bn).set_default(false).describe("Add post batchnorm.");
@@ -57,9 +56,6 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
.set_default(false)
.describe("Add post activation after sum");
    DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
- DMLC_DECLARE_FIELD(enable_float_output)
- .set_default(false)
- .describe("Whether to enable float32 output");
    DMLC_DECLARE_FIELD(dedup_sum).set_default(false).describe("deduplicated sum input");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
@@ -73,9 +69,7 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
"The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized convolution op to calculate primitive scale");
- DMLC_DECLARE_FIELD(amp_out_dtype)
- .set_default(dmlc::optional<int>())
-          MXNET_ADD_ALL_TYPES.describe("The output type deduced from the fused amp_cast.");
+ DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
}
};
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index 976dc83fe5..f734b67878 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -41,20 +41,16 @@ namespace op {
struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
bool quantized;
- bool enable_float_output;
bool with_eltwise;
bool with_sum;
  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
dmlc::optional<bool> channel_wise_quantize;
- dmlc::optional<int> amp_out_dtype; // mshadow dtype of a fused amp_cast node
+ dmlc::optional<int> enabled_float_output;
DMLC_DECLARE_PARAMETER(DNNLFCParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
"Whether it's a quantized FullyConnected operator");
- DMLC_DECLARE_FIELD(enable_float_output)
- .set_default(false)
- .describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(with_eltwise)
.set_default(false)
        .describe("Whether there's a post with_eltwise after FullyConnected operator");
@@ -74,9 +70,7 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
DMLC_DECLARE_FIELD(channel_wise_quantize)
.set_default(dmlc::optional<bool>())
.describe("Whether support channel-wise-quantize for weight.");
- DMLC_DECLARE_FIELD(amp_out_dtype)
- .set_default(dmlc::optional<int>())
-          MXNET_ADD_ALL_TYPES.describe("The output type deduced from the fused amp_cast.");
+ DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
}
};
@@ -100,7 +94,7 @@ class FCInputIndex {
const bool has_bias = !full_param.default_param.no_bias;
const bool quantized = dnnl_param.quantized;
const bool sum_input_quantized =
- quantized && dnnl_param.with_sum && !dnnl_param.enable_float_output;
+ quantized && dnnl_param.with_sum &&
!dnnl_param.enabled_float_output.has_value();
const bool channel_wise = quantized &&
dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize.value();
diff --git a/src/operator/nn/dnnl/dnnl_reduce.cc b/src/operator/nn/dnnl/dnnl_reduce.cc
index 36aeec15b3..c5f20abb71 100644
--- a/src/operator/nn/dnnl/dnnl_reduce.cc
+++ b/src/operator/nn/dnnl/dnnl_reduce.cc
@@ -219,7 +219,7 @@ void DNNLReduceFwd::Execute(const Tensors& tensors) const {
auto input_mem = tensors.data.GetDNNLData();
if (tensors.out.shape().Size() == 1) {
// scalar result
-    auto out_mem = dnnl::memory(reduce_pd->dst_desc(), engine, tensors.out.data().dptr<float>());
+    auto out_mem = dnnl::memory(reduce_pd->dst_desc(), engine, tensors.out.data().dptr_);
    stream->RegisterPrimArgs(*reduce_fwd, {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DST, out_mem}});
} else {
auto desc = reduce_pd->dst_desc();
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index f27b9a7772..c2db724cb8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -233,8 +233,7 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];
-  if ((common::is_float(lhs.type_flag_) || common::is_bfloat(lhs.type_flag_)) &&
-      (common::is_float(rhs.type_flag_) || common::is_bfloat(rhs.type_flag_))) {
+  if ((common::is_float(lhs.type_flag_)) && (common::is_float(rhs.type_flag_))) {
if (lhs.type_flag_ == out.type_flag_) {
      MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]);
} else {
@@ -370,8 +369,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
    MixedBinaryElemwiseCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
} else {
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-    if ((common::is_float(lhs.type_flag_) || common::is_bfloat(lhs.type_flag_)) &&
-        (common::is_float(rhs.type_flag_) || common::is_bfloat(rhs.type_flag_))) {
+    if ((common::is_float(lhs.type_flag_)) && (common::is_float(rhs.type_flag_))) {
if (lhs.type_flag_ == out.type_flag_) {
MixedAllRealBinaryBroadcastCompute<xpu, ROP>(
            attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape);
diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h
index 45fbd643e0..5d0700f56b 100644
--- a/src/operator/numpy/np_true_divide-inl.h
+++ b/src/operator/numpy/np_true_divide-inl.h
@@ -115,8 +115,7 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs& attrs,
}
} else {
// Case when types of the 2 input tensors are different
-    if ((common::is_float(lhs.type_flag_) || common::is_bfloat(lhs.type_flag_)) &&
-        (common::is_float(rhs.type_flag_) || common::is_bfloat(rhs.type_flag_))) {
+    if ((common::is_float(lhs.type_flag_)) && (common::is_float(rhs.type_flag_))) {
// both lhs and rhs are float types, output type is the more precise one
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
@@ -239,8 +238,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs,
});
}
} else {
-    if ((common::is_float(lhs.type_flag_) || common::is_bfloat(lhs.type_flag_)) &&
-        (common::is_float(rhs.type_flag_) || common::is_bfloat(rhs.type_flag_))) {
+    if ((common::is_float(lhs.type_flag_)) && (common::is_float(rhs.type_flag_))) {
      // lhs and rhs have different float types, the output is the more precise one
TBlob temp_tblob;
if (lhs.type_flag_ == out.type_flag_) {
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index cfff87c3c6..c8656e7967 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -102,6 +102,7 @@ struct SampleUniformParam : public dmlc::Parameter<SampleUniformParam>,
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -122,6 +123,7 @@ struct SampleNormalParam : public dmlc::Parameter<SampleNormalParam>, NormalPara
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -144,6 +146,7 @@ struct SampleGammaParam : public dmlc::Parameter<SampleGammaParam>, GammaParam,
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -166,6 +169,7 @@ struct SampleExponentialParam : public dmlc::Parameter<SampleExponentialParam>,
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -188,6 +192,7 @@ struct SamplePoissonParam : public dmlc::Parameter<SamplePoissonParam>,
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -210,6 +215,7 @@ struct SampleBinomialParam : public dmlc::Parameter<SampleBinomialParam>,
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -232,6 +238,7 @@ struct SampleNegBinomialParam : public dmlc::Parameter<SampleNegBinomialParam>,
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -256,6 +263,7 @@ struct SampleGenNegBinomialParam : public dmlc::Parameter<SampleGenNegBinomialPa
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
+ .add_enum("bfloat16", mshadow::kBfloat16)
.set_default(-1)
.describe(
"DType of the output in case this can't be inferred. "
@@ -396,7 +404,7 @@ static inline void uniform_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> low, high;
GetSamplingTempData<xpu, float>(param.low, param.high, ctx, &low, &high);
UniformSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(low, high, out, pgen, s);
@@ -414,7 +422,7 @@ static inline void normal_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> loc, scale;
GetSamplingTempData<xpu, float>(param.loc, param.scale, ctx, &loc, &scale);
NormalSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(loc, scale, out, pgen, s);
@@ -433,7 +441,7 @@ static inline void gamma_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> alpha, beta;
GetSamplingTempData<xpu, float>(param.alpha, param.beta, ctx, &alpha, &beta);
GammaSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(alpha, beta, out, pgen, s);
@@ -451,7 +459,7 @@ static inline void exponential_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> lam, dummy;
GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
ExponentialSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(lam, out, pgen, s);
@@ -469,7 +477,7 @@ static inline void poisson_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> lam, dummy;
GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
PoissonSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(lam, out, pgen, s);
@@ -488,7 +496,7 @@ static inline void binomial_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> n, p;
GetSamplingTempData<xpu, float>(param.n, param.p, ctx, &n, &p);
BinomialSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(n, p, out, pgen, s);
@@ -507,7 +515,7 @@ static inline void neg_binomial_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> k, p;
GetSamplingTempData<xpu, float>(param.k, param.p, ctx, &k, &p);
NegativeBinomialSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(k, p, out, pgen, s);
@@ -528,7 +536,7 @@ static inline void gen_neg_binomial_op(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> mu, alpha;
GetSamplingTempData<xpu, float>(param.mu, param.alpha, ctx, &mu, &alpha);
GeneralizedNegativeBinomialSampler<xpu> sampler;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, OType, _, {
    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
sampler.Sample(mu, alpha, out, pgen, s);
@@ -799,11 +807,11 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs,
dtype = mxnet::common::GetDefaultDtype();
}
}
-  bool dtype_ok =
-      (dtype == mshadow::kFloat16) || (dtype == mshadow::kFloat32) || (dtype == mshadow::kFloat64);
-  CHECK(dtype_ok) << "Output type must be float16, float32, float64: dtype is " << dtype_out
-                  << " vs " << mshadow::kFloat16 << " or " << mshadow::kFloat32 << " or "
-                  << mshadow::kFloat64;
+  bool dtype_ok = dtype == mshadow::kBfloat16 || dtype == mshadow::kFloat16 ||
+                  dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64;
+  CHECK(dtype_ok) << "Output type must be bfloat16, float16, float32, float64: dtype is "
+                  << dtype_out << " vs " << mshadow::kBfloat16 << " or " << mshadow::kFloat16
+                  << " or " << mshadow::kFloat32 << " or " << mshadow::kFloat64;
TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
return true;
}
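
With the enum entries added above, the sampling operators now accept bfloat16
as an output dtype. A usage sketch (the 'bfloat16' dtype string on the Python
front end is an assumption):

    import mxnet as mx

    x = mx.nd.random.uniform(low=0, high=1, shape=(4,), dtype='bfloat16')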
diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
index 6905b118ba..df50a1abdb 100644
--- a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
+++ b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc
@@ -60,7 +60,7 @@ bool DNNLBatchDotShape(const nnvm::NodeAttrs& attrs,
}
out_shapes->at(DotOut::out) = base_out_shapes[DotOut::out];
- if (param.quantized && !param.enable_float_output) {
+ if (param.quantized && !param.enabled_float_output.has_value()) {
SHAPE_ASSIGN_CHECK(*out_shapes, DotOut::out_min, mshadow::Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, DotOut::out_max, mshadow::Shape1(1));
}
@@ -74,6 +74,11 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
const DNNLDotParam& param = nnvm::get<DNNLDotParam>(attrs.parsed);
const size_t base_num_inputs = 2;
if (param.quantized) {
+ if (in_types->at(DotIn::lhs) == mshadow::kBfloat16 ||
+ in_types->at(DotIn::rhs) == mshadow::kBfloat16) {
+ return false;
+ }
+
    CHECK(in_types->at(DotIn::lhs) == mshadow::kInt8 || in_types->at(DotIn::lhs) == mshadow::kUint8)
<< "Quantized batch-dot lhs only supports int8/uint8 input, while "
<< in_types->at(DotIn::lhs) << " is given.";
@@ -85,8 +90,8 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
}
- if (param.enable_float_output) {
- TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+ if (param.enabled_float_output.has_value()) {
+      TYPE_ASSIGN_CHECK(*out_types, DotOut::out, param.enabled_float_output.value());
} else {
    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kInt8);
@@ -97,9 +102,22 @@ bool DNNLBatchDotType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*out_types, DotOut::out_max, mshadow::kFloat32);
}
} else {
- TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kFloat32);
- TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kFloat32);
- TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+ if ((*in_types)[DotIn::lhs] == mshadow::kBfloat16 ||
+ (*in_types)[DotIn::rhs] == mshadow::kBfloat16) {
+ TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kBfloat16);
+ TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kBfloat16);
+ if (param.enabled_float_output.has_value()) {
+ CHECK_EQ(param.enabled_float_output.value(), mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+ } else {
+ TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kBfloat16);
+ }
+ } else {
+ CHECK(!param.enabled_float_output.has_value());
+ TYPE_ASSIGN_CHECK(*in_types, DotIn::lhs, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_types, DotIn::rhs, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_types, DotOut::out, mshadow::kFloat32);
+ }
}
return true;
@@ -122,7 +140,7 @@ NNVM_REGISTER_OP(_sg_onednn_batch_dot)
})
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
- return (param.quantized && !param.enable_float_output) ? 3 : 1;
+    return (param.quantized && !param.enabled_float_output.has_value()) ? 3 : 1;
})
.set_attr_parser(ParamParser<DNNLDotParam>)
.set_attr<nnvm::FListInputNames>(
@@ -140,7 +158,7 @@ NNVM_REGISTER_OP(_sg_onednn_batch_dot)
"FListOutputNames",
[](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
- if (param.quantized && !param.enable_float_output) {
+ if (param.quantized && !param.enabled_float_output.has_value()) {
          return std::vector<std::string>{"output", "min_output", "max_output"};
} else {
return std::vector<std::string>{"output"};
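
The non-quantized branch of DNNLBatchDotType above encodes a simple rule; as a
sketch:

    def batch_dot_dtypes(lhs, rhs, enabled_float_output=None):
        # if either input is bf16, both are forced to bf16; the output is
        # bf16 unless a float32 output was imposed via enabled_float_output
        if 'bfloat16' in (lhs, rhs):
            out = 'float32' if enabled_float_output == 'float32' else 'bfloat16'
            return 'bfloat16', 'bfloat16', out
        return 'float32', 'float32', 'float32'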
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 5b2c9ad3e0..5d440d6d3f 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -253,7 +253,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
post_requantize_ = true;
weight_channelwise_scale = true;
}
- if (dnnl_param.enable_float_output) {
+ if (dnnl_param.enabled_float_output.has_value()) {
weight_channelwise_scale = true;
}
    data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
@@ -270,7 +270,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
if (dnnl_param.with_sum) {
      sum_in_scale = GetQuantizeScale(inputs[in_sum].dtype(), cached_sum_min_, cached_sum_max_);
}
- if (post_requantize_ || dnnl_param.enable_float_output) {
+ if (post_requantize_ || dnnl_param.enabled_float_output.has_value()) {
if (post_requantize_) {
        output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
cached_output_min_,
@@ -387,7 +387,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
      DNNLConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), new_inputs, req, {output});
}
- if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
+ if (dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value()) {
*outputs[kMin].data().dptr<float>() = cached_output_min_;
*outputs[kMax].data().dptr<float>() = cached_output_max_;
}
@@ -470,7 +470,6 @@ static void SgDNNLConvParamParser(nnvm::NodeAttrs* attrs) {
    auto& post_act_param = (param_.full_conv_param.dnnl_param.with_act && !with_act) ?
param_.full_conv_param.act_param :
param_.full_conv_param.postsum_act_param;
- with_act = true;
if (node_name == "Activation") {
const auto act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
post_act_param.alg = GetDNNLActAlgo(act_param);
@@ -483,6 +482,7 @@ static void SgDNNLConvParamParser(nnvm::NodeAttrs* attrs) {
post_act_param.alg = dnnl::algorithm::eltwise_bounded_relu;
post_act_param.alpha = clip_param.a_max;
}
+ with_act = true;
}
});
attrs->parsed = std::move(param_);
@@ -521,7 +521,7 @@ static std::vector<std::string> SgDNNLConvListInputNames(const NodeAttrs& attrs)
static std::vector<std::string> SgDNNLConvListOutputNames(const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
if (param.full_conv_param.dnnl_param.quantized &&
- !param.full_conv_param.dnnl_param.enable_float_output) {
+ !param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
return std::vector<std::string>{"output", "output_min", "output_max"};
} else {
return std::vector<std::string>{"output"};
@@ -582,7 +582,7 @@ static bool SgDNNLConvInferShape(const nnvm::NodeAttrs& attrs,
}
}
out_shapes->at(0) = base_out_shapes[0];
- if (!param.full_conv_param.dnnl_param.enable_float_output) {
+ if (!param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
}
@@ -636,8 +636,9 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& attrs,
}
}
- if (param.full_conv_param.dnnl_param.enable_float_output) {
- TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+ if (param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
+    TYPE_ASSIGN_CHECK(
+        *out_types, 0, param.full_conv_param.dnnl_param.enabled_float_output.value());
} else {
if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
@@ -656,8 +657,8 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& attrs,
return result;
} else {
bool result = DefaultSubgraphOpType(attrs, in_types, out_types);
- if (param.full_conv_param.dnnl_param.amp_out_dtype.has_value()) {
- (*out_types)[0] = param.full_conv_param.dnnl_param.amp_out_dtype.value();
+ if (param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
+    (*out_types)[0] = param.full_conv_param.dnnl_param.enabled_float_output.value();
}
return result;
}
@@ -690,7 +691,7 @@ static bool SgDNNLConvOpStorageType(const nnvm::NodeAttrs& attrs,
}
}
out_stypes->at(0) = base_out_stypes[0];
- if (!param.full_conv_param.dnnl_param.enable_float_output) {
+ if (!param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
}
@@ -754,10 +755,11 @@ NNVM_REGISTER_OP(_sg_onednn_conv)
.set_num_inputs(SgDNNLConvNumInputs)
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
- return param.full_conv_param.dnnl_param.quantized &&
- !param.full_conv_param.dnnl_param.enable_float_output ?
- 3 :
- 1;
+ if (param.full_conv_param.dnnl_param.quantized &&
+ !param.full_conv_param.dnnl_param.enabled_float_output.has_value()) {
+ return 3;
+ }
+ return 1;
})
.set_attr_parser(SgDNNLConvParamParser)
    .set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLConvListInputNames)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 22971bf487..ae21195706 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -116,7 +116,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
const auto& dnnl_param = full_param_.dnnl_param;
const bool has_bias = !default_param.no_bias;
const bool quantized = dnnl_param.quantized;
-  const bool out_quantized = dnnl_param.quantized && !dnnl_param.enable_float_output;
+  const bool out_quantized = dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value();
  const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize.value();
@@ -253,7 +253,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
DNNLStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
DNNLStream::Get()->Submit();
- if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
+ if (dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value()) {
float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
@@ -395,7 +395,7 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
support_channelwise_scale = true;
fuse_requantize = true;
}
- if (dnnl_param.enable_float_output) {
+ if (dnnl_param.enabled_float_output.has_value()) {
support_channelwise_scale = true;
}
// channel_wise support_channelwise_scale result
@@ -461,7 +461,7 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
size_t num_channel = cached_weight_.shape()[0];
float out_scale = 1.0f;
- if (fuse_requantize || dnnl_param.enable_float_output) {
+ if (fuse_requantize || dnnl_param.enabled_float_output.has_value()) {
float tmp_scale_ = 1.0f;
if (fuse_requantize) {
if (dnnl_param.with_eltwise) {
@@ -513,7 +513,7 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
out_scale = data_scale_ * weight_scales_[0];
}
- if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+ if (dnnl_param.with_sum && !dnnl_param.enabled_float_output.has_value()) {
float sum_in_scale =
        GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
full_param_.sum_scale = out_scale / sum_in_scale;
@@ -657,7 +657,7 @@ static std::vector<std::string> SgDNNLFCListInputNames(const NodeAttrs& attrs) {
input_names.emplace_back("bias_max");
}
}
- if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+ if (dnnl_param.with_sum && !dnnl_param.enabled_float_output.has_value()) {
input_names.emplace_back("sum_min");
input_names.emplace_back("sum_max");
}
@@ -668,7 +668,7 @@ static std::vector<std::string> SgDNNLFCListInputNames(const NodeAttrs& attrs) {
static std::vector<std::string> SgDNNLFCListOutputNames(const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
if (full_param.dnnl_param.quantized) {
- if (full_param.dnnl_param.enable_float_output)
+ if (full_param.dnnl_param.enabled_float_output.has_value())
return std::vector<std::string>{"output"};
else
return std::vector<std::string>{"output", "output_min", "output_max"};
@@ -709,7 +709,7 @@ static bool SgDNNLFCInferShape(const nnvm::NodeAttrs& attrs,
}
out_shapes->at(0) = base_out_shapes[0];
- if (!full_param.dnnl_param.enable_float_output) {
+ if (!full_param.dnnl_param.enabled_float_output.has_value()) {
SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
}
@@ -754,8 +754,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
}
}
if (idx.IsSumExist()) {
- if (full_param.dnnl_param.enable_float_output) {
- TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+ if (full_param.dnnl_param.enabled_float_output.has_value()) {
+      TYPE_ASSIGN_CHECK(*in_types, idx.sum, full_param.dnnl_param.enabled_float_output.value());
} else {
      CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) == mshadow::kUint8)
          << "QuantizedFullyConnected sum input only supports int8/uint8, while "
@@ -766,8 +766,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
}
- if (full_param.dnnl_param.enable_float_output) {
- TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+ if (full_param.dnnl_param.enabled_float_output.has_value()) {
+    TYPE_ASSIGN_CHECK(*out_types, 0, full_param.dnnl_param.enabled_float_output.value());
} else {
if (full_param.dnnl_param.min_calib_range.has_value() &&
full_param.dnnl_param.max_calib_range.has_value()) {
@@ -786,8 +786,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
return true;
} else {
bool result = DefaultSubgraphOpType(attrs, in_types, out_types);
- if (full_param.dnnl_param.amp_out_dtype.has_value()) {
- (*out_types)[0] = full_param.dnnl_param.amp_out_dtype.value();
+ if (full_param.dnnl_param.enabled_float_output.has_value()) {
+ (*out_types)[0] = full_param.dnnl_param.enabled_float_output.value();
}
return result;
}
@@ -814,7 +814,7 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& attrs,
}
out_attrs->at(0) = base_out_attrs[0];
- if (!full_param.dnnl_param.enable_float_output) {
+ if (!full_param.dnnl_param.enabled_float_output.has_value()) {
type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
}
@@ -890,8 +890,11 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected)
})
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
-      return (full_param.dnnl_param.quantized && !full_param.dnnl_param.enable_float_output) ? 3 :
-                                                                                               1;
+ if (full_param.dnnl_param.quantized &&
+ !full_param.dnnl_param.enabled_float_output.has_value()) {
+ return 3;
+ }
+ return 1;
})
.set_attr_parser(SgDNNLFCParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLFCListInputNames)
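
The pattern repeated throughout this file replaces the boolean enable_float_output with the
optional enabled_float_output: its presence now enables a floating-point output, and its value
names the dtype. A rough Python model of the output-count rule registered just above (plain
dicts stand in for node attributes; this is a sketch, not MXNet API):

    # Sketch: presence of 'enabled_float_output' plays the role of dmlc::optional::has_value().
    def num_outputs(attrs):
        quantized = attrs.get('quantized') == 'True'
        return 3 if quantized and 'enabled_float_output' not in attrs else 1

    assert num_outputs({'quantized': 'True'}) == 3  # min/max outputs are kept
    assert num_outputs({'quantized': 'True', 'enabled_float_output': 'float32'}) == 1
    assert num_outputs({'enabled_float_output': 'bfloat16'}) == 1
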
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
index 2c19b7b68d..34d3c7a86b 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
@@ -90,7 +90,7 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
if (EndsWith(output_n->op()->name, "elemwise_add")) {
if (quantized_) {
auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
- if (!fc_param.dnnl_param.enable_float_output) {
+ if (!fc_param.dnnl_param.enabled_float_output.has_value()) {
        // For quantized graph, when FC floating point output is not enabled, elementwise add
        // must also be quantized (min and max values have to be already stored in elementwise add).
CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
@@ -234,8 +234,10 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
// sum_tensor.data --> fc_out.max
// sum_tensor.min --> sum_tensor.min
// sum_tensor.max --> sum_tensor.max
-      const int not_rotated_end =
-          (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.enable_float_output) ? 2 : 0;
+      const int not_rotated_end =
+          (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.enabled_float_output.has_value())
+              ? 2
+              : 0;
std::rotate(input_entries->begin() + base_inputs - 1,
input_entries->end() - 1 - not_rotated_end,
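
The rotate above moves the freshly appended sum tensor in front of any trailing calibration
entries; not_rotated_end is 2 exactly when the quantized FC still produces min/max outputs. A
toy model of that rotation, assuming the third std::rotate argument (cut off above) closes the
range just before the kept entries (entry names and base_inputs are illustrative):

    def rotate_sum_input(entries, base_inputs, not_rotated_end):
        first = base_inputs - 1
        last = len(entries) - not_rotated_end
        middle = last - 1  # the sum tensor, appended last before any min/max entries
        entries[first:last] = entries[middle:last] + entries[first:middle]
        return entries

    ins = ['data', 'weight', 'bias', 'fc_min', 'fc_max', 'sum', 'sum_min', 'sum_max']
    assert rotate_sum_input(ins, base_inputs=4, not_rotated_end=2) == \
           ['data', 'weight', 'bias', 'sum', 'fc_min', 'fc_max', 'sum_min', 'sum_max']
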
diff --git a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
index 55eec1cab8..6ec7c54e38 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_amp_property.h
@@ -128,7 +128,7 @@ class SgDNNLPostAMPProperty : public SubgraphProperty {
});
CHECK_NOTNULL(fuse_node);
CHECK_NOTNULL(amp_node);
- fuse_node->attrs.dict["amp_out_dtype"] = amp_node->attrs.dict["dtype"];
+    fuse_node->attrs.dict["enabled_float_output"] = amp_node->attrs.dict["dtype"];
fuse_node->op()->attr_parser(&(fuse_node->attrs));
return fuse_node;
}
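
This pass folds a trailing amp_cast into the preceding fused op by copying the cast's target
dtype into enabled_float_output; that combined parameter then drives type inference in the
hunks above. A minimal sketch with plain dicts standing in for NodeAttrs (not MXNet API):

    def fold_amp_cast(fuse_attrs, amp_attrs):
        # the fused op inherits the amp_cast target dtype as its float output type
        fuse_attrs['enabled_float_output'] = amp_attrs['dtype']
        return fuse_attrs

    assert fold_amp_cast({'num_hidden': '8'}, {'dtype': 'bfloat16'}) == \
           {'num_hidden': '8', 'enabled_float_output': 'bfloat16'}
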
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 14717592b0..94c7e63085 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -217,7 +217,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
    // When only a quantized operator and requantize are fused, set min/max_calib_range;
    // when quantized operator + requantize + dequantize are fused, set the dequantize flag to true.
if (dequantize_node != nullptr) {
- fuse_node->attrs.dict["enable_float_output"] = "True";
+      fuse_node->attrs.dict["enabled_float_output"] = type_string(mshadow::kFloat32);
} else {
      fuse_node->attrs.dict["min_calib_range"] =
          std::to_string(requantize_param.min_calib_range.value());
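
With the combined parameter, the dequantize fusion stores a dtype name where it used to store
the string "True"; the requantize-only path is unchanged. A sketch of the two branches (plain
dicts, attribute values as strings, as in the node dictionaries above):

    def post_quantize_rewrite(fuse_attrs, min_calib, max_calib, dequantize_present):
        if dequantize_present:
            fuse_attrs['enabled_float_output'] = 'float32'  # was: enable_float_output = 'True'
        else:
            fuse_attrs['min_calib_range'] = str(min_calib)
            fuse_attrs['max_calib_range'] = str(max_calib)
        return fuse_attrs

    assert post_quantize_rewrite({}, -1.0, 1.0, True) == {'enabled_float_output': 'float32'}
    assert post_quantize_rewrite({}, -1.0, 1.0, False) == \
           {'min_calib_range': '-1.0', 'max_calib_range': '1.0'}
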
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer-inl.h b/src/operator/subgraph/dnnl/dnnl_transformer-inl.h
index 542c15ad36..e4ece57d48 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer-inl.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer-inl.h
@@ -29,18 +29,14 @@ namespace op {
struct DNNLSelfAttParam : public dmlc::Parameter<DNNLSelfAttParam> {
int heads;
bool quantized;
- bool enable_float_output;
-  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
-  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
-  dmlc::optional<int> amp_out_dtype;  // mshadow dtype of a fused amp_cast node
+  dmlc::optional<float> min_calib_range;     // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;     // max float value calculated from calibration dataset
+  dmlc::optional<int> enabled_float_output;  // mshadow dtype of a fused amp_cast node
DMLC_DECLARE_PARAMETER(DNNLSelfAttParam) {
DMLC_DECLARE_FIELD(heads).describe("Set number of heads.");
DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
"Whether it's a quantized self attention matmul operator.");
- DMLC_DECLARE_FIELD(enable_float_output)
- .set_default(false)
- .describe("Whether to enable float32 output.");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
@@ -53,9 +49,7 @@ struct DNNLSelfAttParam : public dmlc::Parameter<DNNLSelfAttParam> {
"The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized self-attention op to calculate primitive scale.");
- DMLC_DECLARE_FIELD(amp_out_dtype)
- .set_default(dmlc::optional<int>())
-        MXNET_ADD_ALL_TYPES.describe("The output type deduced from the fused amp_cast.");
+ DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
}
};
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer.cc b/src/operator/subgraph/dnnl/dnnl_transformer.cc
index baeb7b4448..67fee5f98f 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer.cc
+++ b/src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -56,7 +56,7 @@ static bool SgDNNLSelfAttShape(const NodeAttrs& attrs,
out_shape->resize(3);
SHAPE_ASSIGN_CHECK(
      *out_shape, 0, mxnet::TShape({qkv_shape[0], params.heads, qkv_shape[1], qkv_shape[1]}));
- if (!params.enable_float_output) {
+ if (!params.enabled_float_output.has_value()) {
SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); // min output
SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); // max output
}
@@ -89,9 +89,9 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);
- if (params.enable_float_output) {
+ if (params.enabled_float_output.has_value()) {
CHECK_EQ(out_types->size(), 1U);
- TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
} else {
CHECK_EQ(out_types->size(), 3U);
    if (params.min_calib_range.has_value() && params.max_calib_range.has_value()) {
@@ -109,8 +109,8 @@ static bool SgDNNLSelfAttQKInferType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
} else if (in_types->at(0) == mshadow::kBfloat16) {
- if (params.amp_out_dtype.has_value()) {
- TYPE_ASSIGN_CHECK(*out_types, 0, params.amp_out_dtype.value());
+ if (params.enabled_float_output.has_value()) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kBfloat16);
}
@@ -239,7 +239,7 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
max_output_ = param_.max_calib_range.value();
      oscale = GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
(data_scale_ * data_scale_);
- } else if (param_.enable_float_output) {
+ } else if (param_.enabled_float_output.has_value()) {
oscale = 1.0f / (data_scale_ * data_scale_);
} else {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
@@ -295,7 +295,7 @@ void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
DNNLStream::Get()->Submit();
- if (param_.quantized && !param_.enable_float_output) {
+ if (param_.quantized && !param_.enabled_float_output.has_value()) {
float* output_min = outputs[1].data().dptr<float>();
float* output_max = outputs[2].data().dptr<float>();
@@ -331,7 +331,7 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
})
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- if (param.quantized && !param.enable_float_output) {
+ if (param.quantized && !param.enabled_float_output.has_value()) {
return 3;
} else {
return 1;
@@ -354,7 +354,8 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
                                     auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
                                     std::vector<std::string> output_names{"output"};
-                                     if (param.quantized && !param.enable_float_output) {
+                                     if (param.quantized &&
+                                         !param.enabled_float_output.has_value()) {
output_names.emplace_back("min_output");
output_names.emplace_back("max_output");
}
@@ -407,7 +408,7 @@ static bool SgDNNLSelfAttValShape(const NodeAttrs& attrs,
0,
mxnet::TShape(
                         {att_shape[0], att_shape[2], att_shape[1] * qkv_shape[2] / params.heads / QKV_NUM}));
- if (!params.enable_float_output) {
+ if (!params.enabled_float_output.has_value()) {
SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); // min output
SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); // max output
}
@@ -455,9 +456,9 @@ static bool SgDNNLSelfAttValInferType(const nnvm::NodeAttrs& attrs,
TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
}
- if (params.enable_float_output) {
+ if (params.enabled_float_output.has_value()) {
CHECK_EQ(out_types->size(), 1U);
- TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_types, 0, params.enabled_float_output.value());
} else {
CHECK_EQ(out_types->size(), 3U);
    if (params.min_calib_range.has_value() && params.max_calib_range.has_value()) {
@@ -478,8 +479,9 @@ static bool SgDNNLSelfAttValInferType(const nnvm::NodeAttrs& attrs,
  } else if (in_types->at(0) == mshadow::kBfloat16 || in_types->at(1) == mshadow::kBfloat16) {
TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kBfloat16);
TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kBfloat16);
- if (params.amp_out_dtype.has_value()) {
- TYPE_ASSIGN_CHECK(*out_types, 0, params.amp_out_dtype.value());
+ if (params.enabled_float_output.has_value()) {
+ CHECK_EQ(params.enabled_float_output.value(), mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kBfloat16);
}
@@ -633,7 +635,7 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
max_output_ = param_.max_calib_range.value();
      oscale = GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
(att_scale_ * qkv_scale_);
- } else if (param_.enable_float_output) {
+ } else if (param_.enabled_float_output.has_value()) {
oscale = 1.0f / (att_scale_ * qkv_scale_);
} else {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
@@ -720,7 +722,7 @@ void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
DNNLStream::Get()->RegisterPrimArgs(*reorder_, reorder_args);
DNNLStream::Get()->Submit();
- if (param_.quantized && !param_.enable_float_output) {
+ if (param_.quantized && !param_.enabled_float_output.has_value()) {
float* output_min = outputs[1].data().dptr<float>();
float* output_max = outputs[2].data().dptr<float>();
@@ -742,7 +744,7 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_valatt)
})
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
- if (param.quantized && !param.enable_float_output) {
+ if (param.quantized && !param.enabled_float_output.has_value()) {
return 3;
} else {
return 1;
@@ -768,7 +770,8 @@ NNVM_REGISTER_OP(_sg_onednn_selfatt_valatt)
                                     auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
                                     std::vector<std::string> output_names{"output"};
-                                     if (param.quantized && !param.enable_float_output) {
+                                     if (param.quantized &&
+                                         !param.enabled_float_output.has_value()) {
output_names.emplace_back("min_output");
output_names.emplace_back("max_output");
}
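
Both self-attention ops now resolve their output type the same way: fp32 stays fp32, while a
bf16 input yields bf16 unless a fused amp_cast requested a wider float output (and the valatt
op additionally insists that the requested type be fp32). A compact model of the non-quantized
branch of SgDNNLSelfAttQKInferType (dtype names as strings; a sketch, not MXNet API):

    def resolve_out_dtype(in_dtype, enabled_float_output=None):
        if in_dtype == 'float32':
            return 'float32'
        if in_dtype == 'bfloat16':
            # a fused amp_cast may widen the output; otherwise stay in bf16
            return enabled_float_output or 'bfloat16'
        raise TypeError(f'unsupported input type: {in_dtype}')

    assert resolve_out_dtype('bfloat16') == 'bfloat16'
    assert resolve_out_dtype('bfloat16', 'float32') == 'float32'
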
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index b675acf9f9..83ad711fa1 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -280,7 +280,7 @@ class UnaryOp : public OpBase {
if (mxnet::common::is_float(inputs[0].type_flag_)) {
UnaryOp::Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else {
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(outputs[0].type_flag_, DType, _, {
MXNET_INT_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
if (inputs[0].Size() != 0) {
@@ -622,7 +622,7 @@ void HardSigmoidForward(const nnvm::NodeAttrs& attrs,
const TBlob& out_data = outputs[0];
const HardSigmoidParam& param = nnvm::get<HardSigmoidParam>(attrs.parsed);
using namespace mxnet_op;
- MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(out_data.type_flag_, DType, _, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Kernel<hard_sigmoid_forward<req_type>, xpu>::Launch(s,
                                                          out_data.Size(),
@@ -650,7 +650,7 @@ void HardSigmoidBackward(const nnvm::NodeAttrs& attrs,
const TBlob& in_grad = outputs[0];
const HardSigmoidParam& param = nnvm::get<HardSigmoidParam>(attrs.parsed);
using namespace mxnet_op;
- MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(in_data.type_flag_, DType, _, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Kernel<hard_sigmoid_backward<req_type>, xpu>::Launch(s,
                                                           in_grad.Size(),
@@ -863,7 +863,7 @@ void NumpyNanToNumOpForward(const nnvm::NodeAttrs& attrs,
return;
}
- MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH_EX(out_data.type_flag_, DType, _, {
DType defaultnan = static_cast<DType>(param.nan);
DType posinf;
DType neginf;
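
Swapping MSHADOW_REAL_TYPE_SWITCH for the _EX variant is what lets these kernels instantiate
for bf16 as well; the extra `_` fills the macro's additional parameter slot. Read it as
widening the dispatched dtype set, roughly (a Python analogue under that assumption, not the
macro's actual definition):

    REAL_TYPES    = {'float16', 'float32', 'float64'}
    REAL_TYPES_EX = REAL_TYPES | {'bfloat16'}  # assumed extension provided by the _EX switch

    def real_type_switch(dtype, kernels, allowed=REAL_TYPES_EX):
        if dtype not in allowed:
            raise TypeError(f'{dtype} is not a dispatched real type')
        return kernels[dtype]

    assert real_type_switch('bfloat16', {'bfloat16': 'bf16 kernel'}) == 'bf16 kernel'
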
diff --git a/tests/python/dnnl/op_cfg.py b/tests/python/dnnl/op_cfg.py
new file mode 100644
index 0000000000..9effb305b1
--- /dev/null
+++ b/tests/python/dnnl/op_cfg.py
@@ -0,0 +1,410 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import namedtuple
+from itertools import product
+from functools import partial
+
+import mxnet as mx
+from mxnet.base import (_is_np_op, _NP_OP_PREFIX, _NP_EXT_OP_PREFIX, _NP_INTERNAL_OP_PREFIX,
+ _OP_NAME_PREFIX_LIST)
+
+PREFIX_TO_MODULE = {
+ _NP_OP_PREFIX: mx.sym.np,
+ _NP_INTERNAL_OP_PREFIX: mx.sym.np._internal,
+ _NP_EXT_OP_PREFIX: mx.sym.npx
+}
+for nd_prefix in _OP_NAME_PREFIX_LIST:
+ module_name = nd_prefix[1:-1] # nd_prefix == '_<module_name>_'
+ PREFIX_TO_MODULE[nd_prefix] = getattr(mx.sym, module_name)
+
+CFG_BASED_ON = '__based_on__'
+CFG_SUBGRAPH = '__subgraph__'
+CFG_RTOL_ATOL = '__rtol_atol__'
+DEFAULT_SHAPE = (8,)
+
+TensorArg = namedtuple('TensorArg', ['gen_tensor'])
+CfgBasedArg = namedtuple('CfgBasedArg', ['gen_arg'])
+SubgraphCfg = namedtuple('SubgraphCfg', ['base_op', 'backend'])
+
+
+def get_op_sym_fn(op_name: str):
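+    """Resolve an operator name, honoring np/npx/internal prefixes, to its mx.sym function."""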
+ for prefix, module in PREFIX_TO_MODULE.items():
+ if op_name.startswith(prefix):
+ return getattr(module, op_name[len(prefix):])
+ try:
+ return getattr(mx.sym, op_name)
+ except AttributeError:
+ try:
+ # op with '_' prefix
+ return getattr(mx.sym, op_name[1:])
+ except AttributeError:
+ return getattr(mx.sym._internal, op_name)
+
+
+def default_tensor(dim_or_shape, dtype):
+ if isinstance(dim_or_shape, (tuple, list)):
+ shape = dim_or_shape
+ return mx.nd.random.normal(0, 1, shape, dtype)
+ dim = dim_or_shape
+ return mx.nd.random.normal(0, 1, DEFAULT_SHAPE*dim, dtype)
+
+
+def common_weight_tensor(shape, dtype, numpy=False):
+ tensor = mx.nd.random.normal(0, 0.1, shape, dtype)
+ return tensor.as_np_ndarray() if numpy else tensor
+
+
+def valatt_attention_tensor(cfg):
+ qkv = cfg['queries_keys_values']
+ batch, seq_len, _ = qkv.shape
+ heads = cfg['heads']
+ att_shape = (batch, heads, seq_len, seq_len)
+ return mx.nd.random.randint(0, 2, att_shape).astype(qkv.dtype)
+
+
+def get_all_ops_cfgs(dtype):
+ return {
+ 'Convolution': {
+ 'data,kernel': [
+ (default_tensor(3, dtype), (3,)),
+ (default_tensor(4, dtype), (3, 3)),
+ (default_tensor(5, dtype), (3, 3, 3))
+ ],
+ 'weight': [TensorArg(common_weight_tensor)],
+ 'bias': [TensorArg(common_weight_tensor)],
+ 'no_bias': [False],
+ 'num_filter': [8]
+ },
+ 'Deconvolution': {CFG_BASED_ON: 'Convolution'},
+ 'FullyConnected': {
+ 'data': [default_tensor(2, dtype)],
+ 'weight': [TensorArg(common_weight_tensor)],
+ 'bias': [TensorArg(common_weight_tensor)],
+ 'num_hidden': [3]
+ },
+ 'Pooling': {
+ 'data,kernel': [
+ (default_tensor(3, dtype), (3,)),
+ (default_tensor(4, dtype), (3, 3)),
+ (default_tensor(5, dtype), (3, 3, 3))
+ ],
+ },
+ '_contrib_AdaptiveAvgPooling2D': {
+        'data,kernel,output_size': [(default_tensor(4, dtype), (2, 2), (4, 4))],
+ },
+
+    ######################################### Casting #########################################
+
+ 'Cast': {
+ 'data': [default_tensor(2, dtype)],
+ 'dtype': ['bool']
+ },
+ '_contrib_quantize_v2': {
+ 'data': [default_tensor(2, dtype)],
+        'min_calib_range': [CfgBasedArg(lambda cfg: cfg['data'].min().asscalar())],
+        'max_calib_range': [CfgBasedArg(lambda cfg: cfg['data'].max().asscalar())],
+ CFG_RTOL_ATOL: [(0, 0)]
+ },
+
+    ##################################### No calculations #####################################
+
+ 'Flatten': {
+ 'data': [default_tensor(2, dtype)]
+ },
+ 'Concat': {
+ '0,1,dim': [
+ (default_tensor(2, dtype), default_tensor(2, dtype), 0)
+ ]
+ },
+ 'Reshape': {
+ '0': [default_tensor(2, dtype)],
+ '1': [(-1,)]
+ },
+ 'transpose': {
+ '0': [default_tensor(2, dtype)]
+ },
+ 'expand_dims': {
+ 'data': [default_tensor(2, dtype)],
+ 'axis': [-1]
+ },
+ 'where': {
+ 'x,y,condition': [
+ (default_tensor(2, dtype),
+ default_tensor(2, dtype),
+ mx.nd.random.randint(0, 2, DEFAULT_SHAPE*2, 'int32'))
+ ],
+ },
+ 'take': {
+ 'a,indices': [
+ (default_tensor(2, dtype),
+ mx.nd.random.randint(0, DEFAULT_SHAPE[0], (2,), 'int32'))
+ ],
+ 'axis': [-1]
+ },
+ 'stack': {
+ '0,1': [
+ (default_tensor(2, dtype), default_tensor(2, dtype))
+ ]
+ },
+ '_split_v2': {
+ 'ary,indices_or_sections': [
+ (default_tensor(2, dtype), (2, 3))
+ ],
+ },
+ 'slice': {
+ 'data,begin,end': [
+ (default_tensor(2, dtype), (0, 1), (2, 4))
+ ],
+ },
+ 'space_to_depth': {
+ 'data,block_size': [
+ (default_tensor(4, dtype), 2)
+ ],
+ },
+ '_copy': {
+ 'data': [default_tensor(2, dtype)]
+ },
+ '_npi_transpose': {CFG_BASED_ON: 'transpose'},
+ '_npi_where': {CFG_BASED_ON: 'where'},
+ '_npx_reshape': {CFG_BASED_ON: 'Reshape'},
+
+    ###################################### Normalization ######################################
+
+ 'LayerNorm': {
+ 'data': [default_tensor(2, dtype)],
+ 'gamma': [TensorArg(common_weight_tensor)],
+ 'beta': [TensorArg(common_weight_tensor)],
+ },
+ 'BatchNorm': {
+ CFG_BASED_ON: 'LayerNorm',
+ 'moving_mean': [TensorArg(common_weight_tensor)],
+ 'moving_var': [
+            TensorArg(lambda shape, dtype: mx.nd.random.uniform(0, 1, shape, dtype))
+ ],
+ },
+ '_contrib_BatchNormWithReLU': {CFG_BASED_ON: 'BatchNorm'},
+ 'LRN': {
+ 'data,nsize': [(default_tensor(2, dtype), 3)]
+ },
+
+    ######################################## Reduction ########################################
+
+ 'mean': {
+ '0': [default_tensor(2, dtype)],
+ 'axis': [0]
+ },
+ 'sum': {CFG_BASED_ON: 'mean'},
+ '_npi_mean': {CFG_BASED_ON: 'mean'},
+ '_npi_sum': {CFG_BASED_ON: 'mean'},
+
+    ######################################### Softmax #########################################
+
+ 'softmax': {
+ 'data': [
+ default_tensor(2, dtype),
+ default_tensor(4, dtype)
+ ],
+ 'axis': [-1]
+ },
+ 'log_softmax': {CFG_BASED_ON: 'softmax'},
+ 'masked_softmax': {
+ CFG_BASED_ON: 'softmax',
+ 'mask': [
+ CfgBasedArg(
+                lambda cfg: mx.nd.random.randint(0, 2, cfg['data'].shape).astype('bool')
+ )
+ ],
+ },
+
+    ################################### Activation / Unary ####################################
+
+ 'Activation': {
+ 'data': [default_tensor(2, dtype)],
+        'act_type': ['sigmoid', 'log_sigmoid', 'relu', 'softrelu', 'tanh', 'mish']
+ },
+ 'LeakyReLU': {
+ 'data': [default_tensor(2, dtype)],
+ 'act_type': ['leaky', 'elu', 'gelu']
+ },
+ '_npi_exp': {
+ '0': [default_tensor(2, dtype)]
+ },
+ '_npi_sqrt': {
+ '0': [mx.nd.random.uniform(0, 8, DEFAULT_SHAPE*2, dtype)]
+ },
+ '_npi_square': {CFG_BASED_ON: '_npi_exp'},
+ '_npi_tanh': {CFG_BASED_ON: '_npi_exp'},
+
+    ######################################### Binary ##########################################
+
+ 'dot': {
+ '0,1': [
+ (default_tensor(3, dtype), default_tensor(3, dtype))
+ ],
+ },
+ 'batch_dot': {CFG_BASED_ON: 'dot'},
+ 'broadcast_add': {CFG_BASED_ON: 'dot'},
+ 'broadcast_div': {CFG_BASED_ON: 'dot'},
+ 'broadcast_mul': {CFG_BASED_ON: 'dot'},
+ 'broadcast_sub': {CFG_BASED_ON: 'dot'},
+ 'elemwise_add': {CFG_BASED_ON: 'dot'},
+ '_npi_dot': {CFG_BASED_ON: 'dot'},
+ '_npi_add': {CFG_BASED_ON: 'dot'},
+ '_npi_multiply': {CFG_BASED_ON: 'dot'},
+ '_npi_subtract': {CFG_BASED_ON: 'dot'},
+ '_npi_true_divide': {CFG_BASED_ON: 'dot'},
+
+        'add_n': {CFG_BASED_ON: 'dot'},  # this is not binary, but can work as binary
+
+    ######################################## Subgraph #########################################
+
+ '_sg_onednn_conv': {
+ CFG_BASED_ON: 'Convolution',
+ CFG_SUBGRAPH: [SubgraphCfg('Convolution', 'ONEDNN')],
+ 'data,kernel': [
+ (default_tensor(4, dtype), (3, 3)),
+ (default_tensor(5, dtype), (3, 3, 3))
+ ]
+ },
+ '_sg_onednn_fully_connected': {
+ CFG_BASED_ON: 'FullyConnected',
+ CFG_SUBGRAPH: [SubgraphCfg('FullyConnected', 'ONEDNN')],
+ },
+ '_sg_onednn_batch_dot': {
+ CFG_BASED_ON: 'batch_dot',
+ CFG_SUBGRAPH: [SubgraphCfg('batch_dot', 'ONEDNN')],
+ },
+ '_sg_onednn_selfatt_qk': {
+ CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk', 'ONEDNN')],
+        'queries_keys_values': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), dtype)],
+ 'heads': [2]
+ },
+ '_sg_onednn_selfatt_valatt': {
+ CFG_BASED_ON: '_sg_onednn_selfatt_qk',
+ CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_valatt', 'ONEDNN')],
+ 'attention': [CfgBasedArg(valatt_attention_tensor)]
+ }
+ }
+
+
+def product_dict(dict_of_lists):
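+    """Yield dicts that cover the Cartesian product of the per-key value lists."""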
+ keys = dict_of_lists.keys()
+ lists = dict_of_lists.values()
+ for scenario in product(*lists):
+ yield dict(zip(keys, scenario))
+
+
+def resolve_cfg_references(args_cfg, all_ops_cfgs):
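+    """Recursively merge in the config referenced by CFG_BASED_ON; local keys take precedence."""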
+ if len(args_cfg) == 0:
+ return {}
+ args_cfg = args_cfg.copy()
+ base_op = args_cfg.pop(CFG_BASED_ON, None)
+ base_cfg = all_ops_cfgs.get(base_op, {})
+ result_cfg = resolve_cfg_references(base_cfg, all_ops_cfgs)
+ result_cfg.update(args_cfg)
+ return result_cfg
+
+
+def get_op_cfg_generator(op_names, dtype):
+ all_ops_cfgs = get_all_ops_cfgs(dtype)
+ for op_name in set(op_names):
+ args_cfgs = all_ops_cfgs[op_name]
+ args_cfgs = resolve_cfg_references(args_cfgs, all_ops_cfgs)
+ for args_scenario in product_dict(args_cfgs):
+ yield (op_name, args_scenario)
+
+
+def get_symblock_from_args_scenario(op_name, args_scenario):
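+    """Build a SymbolBlock for op_name together with the input tensors from args_scenario."""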
+ args_scenario = args_scenario.copy()
+ args_scenario.pop(CFG_RTOL_ATOL, None) # not used here
+ subgraph_cfg = args_scenario.pop(CFG_SUBGRAPH, None)
+ if subgraph_cfg is None:
+ op_sym_fn = get_op_sym_fn(op_name)
+ else:
+ op_sym_fn = get_op_sym_fn(subgraph_cfg.base_op)
+
+    # split args that are bound together under a comma-joined key
+ binded_args = [(k, v) for k, v in args_scenario.items() if ',' in k]
+ for arg_names, arg_cfgs in binded_args:
+ args_scenario.pop(arg_names)
+ arg_names = arg_names.replace(' ', '').split(',')
+ assert isinstance(arg_cfgs, tuple) and len(arg_cfgs) == len(arg_names)
+ for arg_name, arg_cfg in zip(arg_names, arg_cfgs):
+ assert arg_name not in args_scenario
+ args_scenario[arg_name] = arg_cfg
+
+ # generate cfg based args
+ for arg_name, arg_cfg in args_scenario.items():
+ if isinstance(arg_cfg, CfgBasedArg):
+ args_scenario[arg_name] = arg_cfg.gen_arg(args_scenario)
+
+ kw_args = {}
+ pos_args = {}
+ for arg_name, arg_cfg in args_scenario.items():
+ if isinstance(arg_cfg, (TensorArg, mx.nd.NDArray, mx.np.ndarray)):
+ arg_cfg = mx.sym.var(arg_name)
+ if _is_np_op(op_name):
+ arg_cfg = arg_cfg.as_np_ndarray()
+ if arg_name.isdigit():
+ pos_args[int(arg_name)] = arg_cfg
+ else:
+ kw_args[arg_name] = arg_cfg
+ pos_args = [pos_args[k] for k in sorted(pos_args.keys())]
+
+ sym = op_sym_fn(*pos_args, **kw_args)
+ if subgraph_cfg is not None:
+ if len(sym.list_outputs()) > 1:
+ sym = sym[0]
+        # add an additional op (+1) so that the graph pass can convert the tested op
+ sym = mx.sym.relu(sym).optimize_for(subgraph_cfg.backend)
+ assert op_name in sym.tojson()
+
+ args_with_shape, args_with_dtype = {}, {}
+ for arg_name, arg_cfg in args_scenario.items():
+ if isinstance(arg_cfg, (mx.nd.NDArray, mx.np.ndarray)):
+ args_with_shape[arg_name] = arg_cfg.shape
+ args_with_dtype[arg_name] = arg_cfg.dtype
+
+    infered_shapes_args, _, infered_shapes_auxs = sym.infer_shape(**args_with_shape)
+    infered_shapes_args = dict(zip(sym.list_arguments(), infered_shapes_args))
+    infered_shapes_auxs = dict(zip(sym.list_auxiliary_states(), infered_shapes_auxs))
+
+    infered_dtypes_args, _, infered_dtypes_auxs = sym.infer_type(**args_with_dtype)
+    infered_dtypes_args = dict(zip(sym.list_arguments(), infered_dtypes_args))
+    infered_dtypes_auxs = dict(zip(sym.list_auxiliary_states(), infered_dtypes_auxs))
+
+ symblock_input_data = {}
+ for arg_name in [*sym.list_arguments(), *sym.list_auxiliary_states()]:
+ tensor_cfg = args_scenario[arg_name]
+ if isinstance(tensor_cfg, TensorArg):
+            shape = infered_shapes_args.get(arg_name, infered_shapes_auxs.get(arg_name, None))
+            dtype = infered_dtypes_args.get(arg_name, infered_dtypes_auxs.get(arg_name, None))
+ tensor = tensor_cfg.gen_tensor(shape, dtype)
+ else:
+ tensor = tensor_cfg
+ symblock_input_data[arg_name] = tensor
+
+    symblock_input_syms = [mx.sym.var(name) for name in symblock_input_data.keys()]
+    if _is_np_op(op_name):
+        symblock_input_syms = [var.as_np_ndarray() for var in symblock_input_syms]
+ symblock = mx.gluon.SymbolBlock(sym, symblock_input_syms)
+ symblock.initialize()
+ assert len(symblock.collect_params()) == 0
+
+ return symblock, list(symblock_input_data.values())
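
The two entry points above are designed to be consumed together, exactly as the parametrized
test in test_amp.py below does. A minimal standalone driver might look like this (op choice
and dtype are illustrative):

    from op_cfg import get_op_cfg_generator, get_symblock_from_args_scenario

    for op_name, args_scenario in get_op_cfg_generator(['softmax'], 'bfloat16'):
        symblock, input_data = get_symblock_from_args_scenario(op_name, args_scenario)
        outputs = symblock(*input_data)  # run the op once per generated configuration
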
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index a23ba3b69c..f58b025dc2 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -90,8 +90,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
if k.find('_quantize') != -1:
assert v['out_type'] == out_type
if k.find(quantized_op_name) != -1:
-    if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
-        or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enable_float_output' in v:
+    if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
+        or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enabled_float_output' in v:
continue
assert 'min_calib_range' in v
assert 'max_calib_range' in v
diff --git a/tests/python/dnnl/subgraphs/test_amp_subgraph.py b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
index b66ea44cde..51460c8cc0 100644
--- a/tests/python/dnnl/subgraphs/test_amp_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_amp_subgraph.py
@@ -35,33 +35,16 @@ AMP_SG_PASS_NAME = 'ONEDNN_AMP'
AMP_DTYPE = 'bfloat16'
-# Checks if amp (after the AMP_SG_PASS_NAME fuse) changes the name of tensors for calibration
-def check_amp_with_quantization(net, data_example, quantized_nodes):
- net.optimize_for(data_example, backend=QUANTIZE_SG_PASS_NAME)
- symnet = net.export(None)[0]
-    nodes = {n['name'] for n in json.loads(symnet.tojson())['nodes'] if n['op'] != 'null'}
- quant_excluded_nodes = list(nodes - set(quantized_nodes))
-
- _, calib_tensors1 = mx.contrib.quantization._quantize_symbol(
- symnet, mx.current_context(), excluded_symbols=quant_excluded_nodes)
-
-    lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
-                                      excluded_sym_names=quantized_nodes, cast_params_offline=True,
-                                      device=mx.current_context())
- lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
- lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
- _, calib_tensors2 = mx.contrib.quantization._quantize_symbol(
- lp_symnet, mx.cpu(), excluded_symbols=quant_excluded_nodes)
- assert calib_tensors1 == calib_tensors2
-
def same_graph_structure(symnet_observed, symnet_expected, expected):
   nodes_obs = json.loads(symnet_observed.tojson(remove_amp_cast=False))['nodes']
   nodes_exp = json.loads(symnet_expected.tojson(remove_amp_cast=False))['nodes']
+ nodes_obs = [(node['op'], node['inputs']) for node in nodes_obs]
+ nodes_exp = [(node['op'], node['inputs']) for node in nodes_exp]
assert (len(nodes_obs) == len(nodes_exp)) == expected
for node_obs, node_exp in zip(nodes_obs, nodes_exp):
-    if node_obs['op'] != node_exp['op'] or node_obs['inputs'] != node_exp['inputs']:
-      assert expected == False
+    if node_obs != node_exp:
+      assert expected == False, '\n'.join([f'{n1} vs {n2}' for n1, n2 in zip(nodes_obs, nodes_exp)])
break
@@ -87,9 +70,6 @@ def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rto
lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
same_graph_structure(lp_symnet, expected_sym, True)
- # check amp with quantization
- check_amp_with_quantization(net, data_example, quantized_nodes)
-
@mx.util.use_np
def test_amp_fc():
@@ -221,7 +201,8 @@ def test_amp_fuse_with_branch():
out = self.fc1(x)
out1 = self.fc2(out)
out1 = nn.Activation('relu')(out1)
- out2 = mx.npx.softmax(out)
+ with nn.HybridBlock.OptConstraint.disable_amp():
+ out2 = mx.npx.softmax(out)
return out1, out2
net = TestNet()
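
The branch test now relies on the new scoped exclusion instead of name lists: ops created
inside the disable_amp context keep their original dtype during AMP conversion. A minimal
usage sketch (layer size illustrative; the context-manager path is taken verbatim from the
hunk above):

    import mxnet as mx
    from mxnet.gluon import nn

    class Head(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.fc = nn.Dense(16)

        def forward(self, x):
            out = self.fc(x)
            with nn.HybridBlock.OptConstraint.disable_amp():
                return mx.npx.softmax(out)  # excluded from bf16 lowering
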
diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
index de680ae352..0aa8396473 100644
--- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -311,7 +311,7 @@ def function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, flatten, relu
if quantize_mode is not None:
attrs['fc']['quantized'] = 'true'
if quantize_mode == 'smart':
- attrs['fc']['enable_float_output'] = 'true'
+ attrs['fc']['enabled_float_output'] = mx.nd.get_dtype_name(mx.np.float32)
num_hidden=10
net = FCWithSumExample(num_hidden, add_op, fc_out_add)
if flatten:
diff --git a/tests/python/dnnl/test_amp.py b/tests/python/dnnl/test_amp.py
index 73be9bb9a8..df0e2f7159 100644
--- a/tests/python/dnnl/test_amp.py
+++ b/tests/python/dnnl/test_amp.py
@@ -20,9 +20,18 @@ from pathlib import Path
curr_path = Path(__file__).resolve().parent
sys.path.insert(0, str(curr_path.parent))
+import pytest
import mxnet as mx
import amp.common as amp_common_tests
+from mxnet.test_utils import assert_almost_equal
+from mxnet.amp.lists.symbol_bf16 import (BF16_FUNCS, BF16_FP32_FUNCS, WIDEST_TYPE_CASTS,
+                                         CONDITIONAL_FP32_FUNCS)
+from op_cfg import get_op_cfg_generator, get_symblock_from_args_scenario, CFG_RTOL_ATOL
+
+
+ALL_BF16_OPS = BF16_FUNCS + BF16_FP32_FUNCS + WIDEST_TYPE_CASTS
+ALL_BF16_OPS += [op_name for op_name, attr_name, attr_vals in CONDITIONAL_FP32_FUNCS]
AMP_DTYPE = 'bfloat16'
@@ -54,3 +63,60 @@ def test_bf16_fp32_ops_order_independence():
@mx.util.use_np
def test_bf16_test_node_excluding():
amp_common_tests.test_amp_node_excluding(AMP_DTYPE)
+
+
+def get_param_name(param):
+ if isinstance(param, (mx.nd.NDArray, mx.np.ndarray)):
+ return 'Tensor' + str(param.shape)
+ if isinstance(param, (tuple, list)):
+ return str(type(param)(get_param_name(elem) for elem in param))
+ return str(param)
+
+
+def get_test_name(param):
+ if isinstance(param, str):
+ return f'"{param}" ' # op_name
+ if isinstance(param, dict):
+ elements = []
+ for args_names, args_cfgs in param.items():
+ if isinstance(args_cfgs, tuple):
+ binded_args = args_names.split(',')
+ for arg_name, arg_val in zip(binded_args, args_cfgs):
+ elements.append(f'"{arg_name}": {get_param_name(arg_val)}')
+ else:
+ arg_name, arg_val = args_names, args_cfgs
+ elements.append(f'"{arg_name}": {get_param_name(arg_val)}')
+ return ' ' + ', '.join(elements)
+    raise TypeError('Op configuration should only consist of its name (str) and arg config (dict)')
+
+
[email protected](argnames=('op_name', 'args_scenario'),
+                         argvalues=get_op_cfg_generator(ALL_BF16_OPS, AMP_DTYPE),
+ ids=get_test_name)
+def test_bf16_op(op_name, args_scenario):
+    symblock, bf16_symblock_input_data = get_symblock_from_args_scenario(op_name, args_scenario)
+ rtol, atol = args_scenario.get(CFG_RTOL_ATOL, (0.01, None))
+
+ fp32_symblock_input_data = []
+ for tensor in bf16_symblock_input_data:
+ if mx.nd.get_dtype_name(tensor.dtype) == 'bfloat16':
+ tensor = tensor.astype('float32')
+ fp32_symblock_input_data.append(tensor)
+
+ try:
+ bf16_outs = symblock(*bf16_symblock_input_data)
+ fp32_outs = symblock(*fp32_symblock_input_data)
+ mx.nd.waitall()
+ except mx.MXNetError as e:
+ pytest.fail(str(e))
+
+ if not isinstance(bf16_outs, (list, tuple)):
+ bf16_outs = [bf16_outs]
+ if not isinstance(fp32_outs, (list, tuple)):
+ fp32_outs = [fp32_outs]
+
+ assert any(mx.nd.get_dtype_name(tensor.dtype) == 'bfloat16'
+ for tensor in bf16_symblock_input_data + bf16_outs)
+ assert len(bf16_outs) == len(fp32_outs)
+ for bf16_out, fp32_out in zip(bf16_outs, fp32_outs):
+        assert_almost_equal(bf16_out.astype('float32'), fp32_out.astype('float32'), rtol, atol)
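
A note on the default tolerance above: bfloat16 keeps 8 significant bits (7 stored mantissa
bits plus the implicit leading one), so machine epsilon is 2**-7 ≈ 0.0078 and a single
rounding step costs at most 2**-8 ≈ 0.0039 in relative terms. The default rtol of 0.01
therefore leaves headroom for one rounding plus mild accumulation:

    assert 2 ** -7 < 0.01  # bf16 machine epsilon vs the default rtol
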