This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b54b7f3  [REFACTOR] Refactor test_quantize.py to use Gluon API (#20227)
b54b7f3 is described below

commit b54b7f36c944b933f2aa94f224821548da5fcaf6
Author: bgawrych <[email protected]>
AuthorDate: Thu Jun 17 14:57:45 2021 +0200

    [REFACTOR] Refactor test_quantize.py to use Gluon API (#20227)
    
    * Refactor test_quantization to gluon API
    
    * review
    
    * Apply review suggestions
    
    * Skip flaky test
---
 .../quantization/mkldnn/mkldnn_quantized_conv.cc   |   2 -
 tests/python/quantization/test_quantization.py     | 873 ++++++++++++---------
 2 files changed, 486 insertions(+), 389 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc 
b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
index e1d6a80..e32b26b 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -38,8 +38,6 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& 
attrs,
                                        const std::vector<NDArray> &in_data,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<NDArray> &out_data) {
-  CHECK_EQ(in_data[0].dtype(), mshadow::kUint8)
-    << "mkldnn_quantized_conv op only supports uint8 as input type";
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
   NDArray weight = in_data[conv::kWeight];
   ConvolutionParam param = nnvm::get<ConvolutionParam>(attrs.parsed);
diff --git a/tests/python/quantization/test_quantization.py 
b/tests/python/quantization/test_quantization.py
index df6e9c6..21866f0 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -29,14 +29,6 @@ import unittest
 import operator
 
 
-def initialize_block_params(block, initializer):
-    for name, param in 
block.collect_params('.*gamma|.*running_var|.*moving_var').items():
-        param.initialize(mx.init.Constant(1))
-    for name, param in 
block.collect_params('.*beta|.*bias|.*moving_mean|.*running_mean').items():
-        param.initialize(mx.init.Constant(0))
-    for name, param in block.collect_params('.*weight').items():
-        param.initialize(initializer)
-
 def collect_block_args_aux(block, sym):
   arg_params, aux_params = dict(), dict()
   for k, v in block.collect_params().items():
@@ -171,26 +163,28 @@ def test_requantize_int32_to_int8():
         assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
         assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
 
-    def check_requantize_with_symbol(shape, min_calib_range=None, 
max_calib_range=None):
+    def check_requantize_with_gluon(shape, min_calib_range=None, 
max_calib_range=None):
         qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, 
shape=shape).astype('int32')
         min_range = mx.nd.array([-1010.0])
         max_range = mx.nd.array([1020.0])
-        sym_data = mx.sym.Variable('data')
-        sym_min_range = mx.sym.Variable('min_range')
-        sym_max_range = mx.sym.Variable('max_range')
-        if min_calib_range is None or max_calib_range is None:
-            requant = mx.sym.contrib.requantize(sym_data, sym_min_range, 
sym_max_range)
-            out = requant._bind(ctx=mx.current_context(),
-                               args={'data':qdata, 'min_range':min_range,
-                               'max_range':max_range})
-            qdata_int8, min_output, max_output = out.forward()
-        else:
-            requant = mx.sym.contrib.requantize(sym_data, sym_min_range, 
sym_max_range,
-                                                
min_calib_range=min_calib_range,
-                                                
max_calib_range=max_calib_range)
-            out = requant._bind(ctx=mx.current_context(), args={'data':qdata, 
'min_range':min_range,
-                               'max_range':max_range})
-            qdata_int8, min_output, max_output = out.forward()
+
+        class RequantizeBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, min_calib_range=None, max_calib_range=None, 
**kwargs):
+                super(RequantizeBlock, self).__init__(**kwargs)
+                self.min_calib_range = min_calib_range
+                self.max_calib_range = max_calib_range
+
+            def hybrid_forward(self, F, x, min_range, max_range):
+                if self.min_calib_range is not None and self.max_calib_range 
is not None:
+                    out = F.contrib.requantize(x, min_range, max_range,
+                                               
min_calib_range=self.min_calib_range,
+                                               
max_calib_range=self.max_calib_range)
+                else:
+                    out = F.contrib.requantize(x, min_range, max_range)
+                return out
+
+        requant = RequantizeBlock(min_calib_range, max_calib_range)  # 
m*_calib_ranges can be None
+        qdata_int8, min_output, max_output = requant(qdata, min_range, 
max_range)
 
         qdata_int8_np, min_output_np, max_output_np = 
requantize_baseline(qdata.asnumpy(), min_range.asscalar(),
                                                                           
max_range.asscalar(),
@@ -200,11 +194,11 @@ def test_requantize_int32_to_int8():
         assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
         assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
 
-    # test with symbol API.
-    check_requantize_with_symbol((3, 4, 10, 10))
-    check_requantize_with_symbol((32, 3, 23, 23))
-    check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, 
max_calib_range=1040.0)
-    check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, 
max_calib_range=523.43)
+    # test with gluon API.
+    check_requantize_with_gluon((3, 4, 10, 10))
+    check_requantize_with_gluon((32, 3, 23, 23))
+    check_requantize_with_gluon((3, 4, 10, 10), min_calib_range=-1050.0, 
max_calib_range=1040.0)
+    check_requantize_with_gluon((32, 3, 23, 23), min_calib_range=-134.349, 
max_calib_range=523.43)
     # Test with nd array API
     check_requantize((3, 4, 10, 10))
     check_requantize((32, 3, 23, 23))
@@ -213,7 +207,7 @@ def test_requantize_int32_to_int8():
 
 
 def test_quantized_conv():
-    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, 
dilate, no_bias, qdtype):
+    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, 
dilate, use_bias, qdtype):
         if is_test_for_native_cpu():
             print('skipped testing quantized_conv for native cpu since it is 
not supported yet')
             return
@@ -229,69 +223,105 @@ def test_quantized_conv():
             return
 
         # run fp32 conv
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        conv = mx.sym.Convolution(data=data, kernel=kernel, 
num_filter=num_filter, pad=pad, stride=stride,
-                                  dilate=dilate, no_bias=no_bias, 
cudnn_off=False, name='conv')
-        arg_shapes, _, _ = conv.infer_shape(data=data_shape)
-        arg_names = conv.list_arguments()
-        conv_exe_fp32 = conv._simple_bind(ctx=mx.current_context(), 
grad_req='null')
+        if len(data_shape) == 4:
+            convfp32 = mx.gluon.nn.Conv2D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        elif len(data_shape) == 5:
+            convfp32 = mx.gluon.nn.Conv3D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        else:
+            print('unsupported shape')
+            assert False
+
         if qdtype == 'uint8':
             data_low = 0.0
             data_high = 127.0
         else:
             data_low = -127.0
             data_high = 127.0
-        conv_exe_fp32.arg_dict[arg_names[0]][:] = 
mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                       
shape=data_shape).astype('int32')
-        conv_exe_fp32.arg_dict[arg_names[1]][:] = 
mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                       
shape=arg_shapes[1]).astype('int32')
-        if not no_bias:
-            conv_exe_fp32.arg_dict[arg_names[2]][:] = 
mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           
shape=arg_shapes[2]).astype('int32')
-        output = conv_exe_fp32.forward()[0]
+
+        convfp32.initialize()
+        input_data = mx.nd.random.uniform(low=data_low,
+                                          high=data_high,
+                                          shape=data_shape
+                                         ).astype('int32').astype('float32')
+        convfp32(input_data) # initialize params
+        mx.nd.waitall()
+        fp32_params = convfp32.collect_params()
+        new_args = dict()
+        new_args['weight'] = mx.nd.random.uniform(low=-127.0,
+                                                  high=127.0,
+                                                  
shape=fp32_params['weight'].shape
+                                                 
).astype('int32').astype('float32')
+        if use_bias:
+           new_args['bias'] = mx.nd.random.uniform(low=-127.0,
+                                                   high=127.0,
+                                                   
shape=fp32_params['bias'].shape
+                                                  
).astype('int32').astype('float32')
+        convfp32.load_dict(new_args, cast_dtype=True, dtype_source='saved')
+
+        output = convfp32(input_data)
 
         # run quantized conv
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        qweight = mx.sym.Variable(name='qweight', dtype='int8')
-        min_data = mx.sym.Variable(name='min_data')
-        max_data = mx.sym.Variable(name='max_data')
-        min_weight = mx.sym.Variable(name='min_weight')
-        max_weight = mx.sym.Variable(name='max_weight')
-        quantized_conv = mx.sym.contrib.quantized_conv(data=qdata, 
weight=qweight, min_data=min_data,
-                                                       max_data=max_data, 
min_weight=min_weight,
-                                                       max_weight=max_weight, 
kernel=kernel,
-                                                       num_filter=num_filter, 
pad=pad, stride=stride,
-                                                       dilate=dilate, 
no_bias=no_bias)
-        qarg_names = quantized_conv.list_arguments()
-        type_dict = None
-        if not no_bias:
-            type_dict = {qarg_names[2]: 'int8'}
-        conv_exe_int8 = quantized_conv._simple_bind(ctx=mx.current_context(), 
type_dict=type_dict, grad_req='null')
-        conv_exe_int8.arg_dict[qarg_names[0]][:] = 
conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype)
-        conv_exe_int8.arg_dict[qarg_names[1]][:] = 
conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
+        class QuantConv(mx.gluon.nn.HybridBlock):
+            def __init__(self, channels, kernel_size, strides=(1, 1),
+                         padding=(0, 0), dilation=(1, 1), use_bias=True, 
**kwargs):
+                super(QuantConv, self).__init__(**kwargs)
+                self.use_bias = use_bias
+                self._kwargs = {'kernel': kernel_size, 'stride': strides, 
'dilate': dilation,
+                                'pad': padding, 'num_filter': channels, 
'no_bias': not use_bias, 'num_group': 1,
+                                'layout': 'NCHW'}
+
+                self.min_data = mx.gluon.Parameter('min_data', 
dtype='float32', allow_deferred_init=True)
+                self.max_data = mx.gluon.Parameter('max_data', 
dtype='float32', allow_deferred_init=True)
+
+                self.weight = mx.gluon.Parameter('weight', dtype='int8', 
allow_deferred_init=True)
+                self.min_weight = mx.gluon.Parameter('min_weight', 
dtype='float32', allow_deferred_init=True)
+                self.max_weight = mx.gluon.Parameter('max_weight', 
dtype='float32', allow_deferred_init=True)
+                
+                if use_bias:
+                    self.bias = mx.gluon.Parameter('bias', dtype='int8', 
allow_deferred_init=True)
+                    self.min_bias = mx.gluon.Parameter('min_bias', 
dtype='float32', allow_deferred_init=True)
+                    self.max_bias = mx.gluon.Parameter('max_bias', 
dtype='float32', allow_deferred_init=True)
+
+            def hybrid_forward(self, F, x, weight, bias=None, min_data=None, 
max_data=None,
+                               min_weight=None, max_weight=None, 
min_bias=None, max_bias=None):
+                out = F.contrib.quantized_conv(data=x, weight=weight, 
bias=bias, 
+                                               min_data=min_data, 
max_data=max_data,
+                                               min_weight=min_weight, 
max_weight=max_weight,
+                                               min_bias=min_bias, 
max_bias=max_bias,
+                                               **self._kwargs)
+                return out
+
+        convint8 = QuantConv(channels=num_filter, kernel_size=kernel, 
strides=stride,
+                             padding=pad, dilation=dilate, use_bias=use_bias)
+
         quantized_range = 127.0
-        if no_bias:
-            conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
-            conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
-        else:
-            conv_exe_int8.arg_dict[qarg_names[2]][:] = 
conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
-            conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
-            conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
-            conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
-        qoutput, min_range, max_range = conv_exe_int8.forward()
-
-        if no_bias:
-            assert_almost_equal(output.asnumpy(), qoutput.asnumpy(), atol = 1)
-        else:
+        qargs = {
+            'weight': new_args['weight'].astype('int8'),
+            'min_data': mx.nd.array([-quantized_range]),
+            'max_data': mx.nd.array([quantized_range]),
+            'min_weight': mx.nd.array([-quantized_range]),
+            'max_weight': mx.nd.array([quantized_range])
+        }
+        if use_bias:
+            qargs.update({
+                'bias': new_args['bias'].astype('int8'),
+                'min_bias': mx.nd.array([-quantized_range]),
+                'max_bias': mx.nd.array([quantized_range]),
+            })
+
+        convint8.load_dict(qargs, cast_dtype=True, dtype_source='saved')
+
+        qoutput, min_range, max_range = convint8(input_data.astype(qdtype))
+
+        if use_bias:
             # with adding bias, accuracy loss should not be greater than one
             diff = mx.nd.abs(output - qoutput.astype(output.dtype))
             cond = mx.nd.lesser(2, diff).sum().asscalar()
             assert cond == 0
+        else:
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy(), atol = 1)
 
     for qdtype in ['int8', 'uint8']:
         check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 
1), True, qdtype)
@@ -314,11 +344,22 @@ def test_quantized_elemwise_add():
             print('skipped testing quantized_elemwise_add for gpu since it is 
not supported yet')
             return
 
-        dataA = mx.sym.Variable(name='dataA', shape=data_shape, 
dtype='float32')
-        dataB = mx.sym.Variable(name='dataB', shape=data_shape, 
dtype='float32')
-        elemwise_add_fp32 = mx.sym.elemwise_add(dataA, dataB)
-        arg_names = elemwise_add_fp32.list_arguments()
-        elemwise_add_fp32_exe = 
elemwise_add_fp32._simple_bind(ctx=mx.current_context(), grad_req='null')
+        class ElemwiseSumBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(ElemwiseSumBlock, self).__init__(**kwargs)
+
+            def hybrid_forward(self, F, dataA, dataB):
+                return F.elemwise_add(dataA, dataB)
+
+        class QuantElemwiseSumBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(QuantElemwiseSumBlock, self).__init__(**kwargs)
+
+            def hybrid_forward(self, F, dataA, dataB, dataA_min, dataA_max, 
dataB_min, dataB_max):
+                return F.contrib.quantized_elemwise_add(dataA, dataB, 
dataA_min, dataA_max, dataB_min, dataB_max)
+
+        elemwise_add_fp32 = ElemwiseSumBlock()
+
         if qtype == 'uint8':
             data_low = 0.0
             data_high = 255.0
@@ -326,31 +367,24 @@ def test_quantized_elemwise_add():
             data_low = -127.0
             data_high = 127.0
 
-        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32')
-        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32')
-        elemwise_add_fp32_exe.arg_dict[arg_names[0]][:] = dataA_val
-
-        elemwise_add_fp32_exe.arg_dict[arg_names[1]][:] = dataB_val
-
-        output = elemwise_add_fp32_exe.forward()[0]
-        qdataA = mx.sym.Variable(name='qdataA', shape=data_shape, dtype=qtype)
-        qdataB = mx.sym.Variable(name='qdataB', shape=data_shape, dtype=qtype)
-        min_dataA = mx.sym.Variable(name='min_dataA', dtype='float32')
-        max_dataA = mx.sym.Variable(name='max_dataA', dtype='float32')
-        min_dataB = mx.sym.Variable(name='min_dataB', dtype='float32')
-        max_dataB = mx.sym.Variable(name='max_dataB', dtype='float32')
-        quantized_elemwise_add = mx.sym.contrib.quantized_elemwise_add(qdataA, 
qdataB, min_dataA, max_dataA, min_dataB, max_dataB)
-        elemwise_add_int8_exe = 
quantized_elemwise_add._simple_bind(ctx=mx.current_context(), grad_req='null')
-        qarg_names = quantized_elemwise_add.list_arguments()
-        elemwise_add_int8_exe.arg_dict[qarg_names[0]][:] = 
elemwise_add_fp32_exe.arg_dict[arg_names[0]].astype(qtype)
-        elemwise_add_int8_exe.arg_dict[qarg_names[1]][:] = 
elemwise_add_fp32_exe.arg_dict[arg_names[1]].astype(qtype)
+        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32').astype('float32')
+        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32').astype('float32')
+
+        output = elemwise_add_fp32(dataA_val, dataB_val)
+
+        #run quantized
+        quantized_elemwise_add = QuantElemwiseSumBlock()
+        dataA_val_int8 = dataA_val.astype(qtype)
+        dataB_val_int8 = dataB_val.astype(qtype)
         quantized_range = 127.0
-        elemwise_add_int8_exe.arg_dict[qarg_names[2]][:] = data_low
-        elemwise_add_int8_exe.arg_dict[qarg_names[3]][:] = data_high
-        elemwise_add_int8_exe.arg_dict[qarg_names[4]][:] = data_low
-        elemwise_add_int8_exe.arg_dict[qarg_names[5]][:] = data_high
-        qoutput, min_range, max_range = elemwise_add_int8_exe.forward()
-        int8_rslt = qoutput.astype(output.dtype)*max_range/0x7fffffff
+        min_dataA = mx.nd.array([data_low])
+        max_dataA = mx.nd.array([data_high])
+        min_dataB = mx.nd.array([data_low])
+        max_dataB = mx.nd.array([data_high])
+        qoutput, min_range, max_range = quantized_elemwise_add(dataA_val_int8, 
dataB_val_int8,
+                                                               min_dataA, 
max_dataA,
+                                                               min_dataB, 
max_dataB)
+        int8_rslt = qoutput.astype(output.dtype) * max_range / 0x7fffffff
         diff = mx.nd.abs(output - int8_rslt)
         cond = mx.nd.lesser(2, diff).sum().asscalar()
         assert cond == 0
@@ -374,11 +408,21 @@ def test_quantized_elemwise_mul():
             print('skipped testing quantized_elemwise_mul for gpu since it is 
not supported yet')
             return
 
-        dataA = mx.sym.Variable(name='dataA', shape=data_shape, 
dtype='float32')
-        dataB = mx.sym.Variable(name='dataB', shape=data_shape, 
dtype='float32')
-        elemwise_mul_fp32 = mx.sym.elemwise_mul(dataA, dataB)
-        arg_names = elemwise_mul_fp32.list_arguments()
-        elemwise_mul_fp32_exe = 
elemwise_mul_fp32._simple_bind(ctx=mx.current_context(), grad_req='null')
+        class ElemwiseMulBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(ElemwiseMulBlock, self).__init__(**kwargs)
+
+            def hybrid_forward(self, F, dataA, dataB):
+                return F.elemwise_mul(dataA, dataB)
+
+        class QuantElemwiseMulBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(QuantElemwiseMulBlock, self).__init__(**kwargs)
+
+            def hybrid_forward(self, F, dataA, dataB, dataA_min, dataA_max, 
dataB_min, dataB_max):
+                return F.contrib.quantized_elemwise_mul(dataA, dataB, 
dataA_min, dataA_max, dataB_min, dataB_max)
+
+        elemwise_mul_fp32 = ElemwiseMulBlock()
         if qtype == 'uint8':
             data_low = 0.0
             data_high = 255.0
@@ -386,31 +430,22 @@ def test_quantized_elemwise_mul():
             data_low = -127.0
             data_high = 127.0
 
-        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32')
-        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32')
-        elemwise_mul_fp32_exe.arg_dict[arg_names[0]][:] = dataA_val
-
-        elemwise_mul_fp32_exe.arg_dict[arg_names[1]][:] = dataB_val
-
-        output = elemwise_mul_fp32_exe.forward()[0]
-
-        qdataA = mx.sym.Variable(name='qdataA', shape=data_shape, dtype=qtype)
-        qdataB = mx.sym.Variable(name='qdataB', shape=data_shape, dtype=qtype)
-        min_dataA = mx.sym.Variable(name='min_dataA', dtype='float32')
-        max_dataA = mx.sym.Variable(name='max_dataA', dtype='float32')
-        min_dataB = mx.sym.Variable(name='min_dataB', dtype='float32')
-        max_dataB = mx.sym.Variable(name='max_dataB', dtype='float32')
-        quantized_elemwise_mul = mx.sym.contrib.quantized_elemwise_mul(qdataA, 
qdataB, min_dataA, max_dataA, min_dataB, max_dataB)
-        elemwise_mul_int8_exe = 
quantized_elemwise_mul._simple_bind(ctx=mx.current_context(), grad_req='null')
-        qarg_names = quantized_elemwise_mul.list_arguments()
-        elemwise_mul_int8_exe.arg_dict[qarg_names[0]][:] = 
elemwise_mul_fp32_exe.arg_dict[arg_names[0]].astype(qtype)
-        elemwise_mul_int8_exe.arg_dict[qarg_names[1]][:] = 
elemwise_mul_fp32_exe.arg_dict[arg_names[1]].astype(qtype)
+        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32').astype('float32')
+        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32').astype('float32')
+
+        output = elemwise_mul_fp32(dataA_val, dataB_val)
+
+        quantized_elemwise_mul = QuantElemwiseMulBlock()
+        dataA_val_int8 = dataA_val.astype(qtype)
+        dataB_val_int8 = dataB_val.astype(qtype)
         quantized_range = 127.0
-        elemwise_mul_int8_exe.arg_dict[qarg_names[2]][:] = data_low
-        elemwise_mul_int8_exe.arg_dict[qarg_names[3]][:] = data_high
-        elemwise_mul_int8_exe.arg_dict[qarg_names[4]][:] = data_low
-        elemwise_mul_int8_exe.arg_dict[qarg_names[5]][:] = data_high
-        qoutput, min_range, max_range = elemwise_mul_int8_exe.forward()
+        min_dataA = mx.nd.array([data_low])
+        max_dataA = mx.nd.array([data_high])
+        min_dataB = mx.nd.array([data_low])
+        max_dataB = mx.nd.array([data_high])
+        qoutput, min_range, max_range = quantized_elemwise_mul(dataA_val_int8, 
dataB_val_int8,
+                                                               min_dataA, 
max_dataA,
+                                                               min_dataB, 
max_dataB)
 
         fp32_rslt = output.asnumpy()
         int8_rslt = qoutput.astype(output.dtype)
@@ -435,38 +470,55 @@ def test_quantized_pooling():
             print('skipped testing quantized_pooling for gpu 5d layout since 
it is not supported yet')
             return
 
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, 
stride=stride,
-                                      pool_type=pool_type, 
global_pool=global_pool, cudnn_off=False,
-                                      pooling_convention=convention)
-        arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape)
-        arg_names = pooling_fp32.list_arguments()
-        pooling_fp32_exe = pooling_fp32._simple_bind(ctx=mx.current_context(), 
grad_req='null')
+        class PoolingBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, kernel=kernel, pad=pad, stride=stride,
+                         pool_type=pool_type, global_pool=global_pool, 
cudnn_off=False,
+                         pooling_convention=convention):
+                super(PoolingBlock, self).__init__()
+                self._kwargs = {'kernel': kernel, 'pad': pad, 'stride': stride,
+                                'pool_type': pool_type, 'global_pool': 
global_pool,
+                                'cudnn_off': False, 'pooling_convention': 
convention}
+
+            def hybrid_forward(self, F, data):
+                return F.Pooling(data, **self._kwargs)
+
+        class QuantPoolingBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, kernel=kernel, pad=pad, stride=stride,
+                         pool_type=pool_type, global_pool=global_pool,
+                         cudnn_off=False, pooling_convention=convention):
+                super(QuantPoolingBlock, self).__init__()
+
+                self._kwargs = {'kernel': kernel, 'pad': pad, 'stride': stride,
+                                'pool_type': pool_type, 'global_pool': 
global_pool, 'cudnn_off': False,
+                                'pooling_convention':convention}
+
+            def hybrid_forward(self, F, data, min_data, max_data):
+                return F.contrib.quantized_pooling(data, min_data, max_data, 
**self._kwargs)
+
+        pooling_fp32 = PoolingBlock()
         if qdtype == 'uint8':
             data_low = 0.0
             data_high = 127.0
         else:
             data_low = -127.0
             data_high = 127.0
-        pooling_fp32_exe.arg_dict[arg_names[0]][:] = 
mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                            
shape=data_shape).astype('int32')
-        output = pooling_fp32_exe.forward()[0]
-
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        min_data = mx.sym.Variable(name='min_data')
-        max_data = mx.sym.Variable(name='max_data')
-        quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, 
min_data=min_data,
-                                                             
max_data=max_data, kernel=kernel,
-                                                             pad=pad, 
stride=stride, pool_type=pool_type,
-                                                             
global_pool=global_pool,
-                                                             
pooling_convention=convention)
-        pooling_int8_exe = 
quantized_pooling._simple_bind(ctx=mx.current_context(), grad_req='null')
-        qarg_names = quantized_pooling.list_arguments()
-        pooling_int8_exe.arg_dict[qarg_names[0]][:] = 
pooling_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
+
+        input_data = mx.nd.random.uniform(low=data_low,
+                                          high=data_high,
+                                          shape=data_shape
+                                         ).astype('int32').astype('float32')
+        output = pooling_fp32(input_data)
+
+        quantized_pooling = QuantPoolingBlock(kernel=kernel, pad=pad, 
stride=stride,
+                                              pool_type=pool_type, 
global_pool=global_pool,
+                                              pooling_convention=convention)
+
+        int8_input_data = input_data.astype(qdtype)
         quantized_range = 127.0
-        pooling_int8_exe.arg_dict[qarg_names[1]][:] = -quantized_range
-        pooling_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range
-        qoutput, min_range, max_range = pooling_int8_exe.forward()
+        min_data = mx.nd.array([-quantized_range])
+        max_data = mx.nd.array([quantized_range])
+
+        qoutput, min_range, max_range = quantized_pooling(int8_input_data, 
min_data, max_data)
 
         if pool_type == 'max':
             assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
@@ -496,7 +548,7 @@ def test_quantized_pooling():
 
 
 def test_quantized_fc():
-    def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, 
flatten=True):
+    def check_quantized_fc(data_shape, num_hidden, use_bias, qdtype, 
flatten=True):
         if is_test_for_native_cpu():
             hasMKL = False
             for key in os.environ.keys():
@@ -514,11 +566,6 @@ def test_quantized_fc():
         def maxabs(a, b):
             return mx.nd.maximum(mx.nd.abs(a), mx.nd.abs(b))
 
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, 
no_bias=no_bias, flatten=flatten)
-        arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape)
-        arg_names = fc_fp32.list_arguments()
-        fc_fp32_exe = fc_fp32._simple_bind(ctx=mx.current_context(), 
grad_req='null')
         int8_range = 127.0
         if qdtype == 'uint8':
             data_low = 0.0
@@ -529,23 +576,33 @@ def test_quantized_fc():
             data_high = 63.0
             quantized_range = 127.0
 
-        data = mx.nd.random.uniform(low=data_low, high=data_high,
-                                    shape=data_shape).astype('int32')
-        weight = mx.nd.random.uniform(low=data_low, high=data_high,
-                                      shape=arg_shapes[1]).astype('int32')
-        fc_fp32_exe.arg_dict[arg_names[0]][:] = data
-        fc_fp32_exe.arg_dict[arg_names[1]][:] = weight
-
+        data = mx.nd.random.uniform(low=data_low,
+                                    high=data_high,
+                                    shape=data_shape
+                                   ).astype('int32').astype('float32')
+        fc_fp32 = mx.gluon.nn.Dense(units=num_hidden, use_bias=use_bias, 
flatten=flatten)
+        fc_fp32.initialize()
+        fc_fp32(data)
+        mx.nd.waitall()
+        fp32_params = fc_fp32.collect_params()
+
+        new_args = dict()
+        new_args['weight'] = mx.nd.random.uniform(low=data_low,
+                                                  high=data_high,
+                                                  
shape=fp32_params['weight'].shape
+                                                 
).astype('int32').astype('float32')
         data_min = mx.nd.min(data).astype('float32')
         data_max = mx.nd.max(data).astype('float32')
-        weight_min = mx.nd.min(weight).astype('float32')
-        weight_max = mx.nd.max(weight).astype('float32')
+        weight_min = mx.nd.min(new_args['weight']).astype('float32')
+        weight_max = mx.nd.max(new_args['weight']).astype('float32')
         data_range = maxabs(data_min, data_max)
         weight_range = maxabs(weight_min, weight_max)
 
-        if not no_bias:
-            bias = mx.nd.random.uniform(low=data_low, high=data_high,
-                                        shape=arg_shapes[2]).astype('int32')
+        if use_bias:
+            bias = mx.nd.random.uniform(low=data_low,
+                                        high=data_high,
+                                        shape=fp32_params['bias'].shape
+                                       ).astype('int32').astype('float32')
             bias_min = mx.nd.min(bias).astype('float32')
             bias_max = mx.nd.max(bias).astype('float32')
             bias_range = maxabs(bias_min, bias_max)
@@ -555,57 +612,79 @@ def test_quantized_fc():
             weight_scale = int8_range / weight_range
             bias_int32_rescale = data_scale * weight_scale / bias_scale
             new_bias = mx.nd.cast(bias, dtype='float32') * bias_int32_rescale
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = new_bias.astype('int32')
-
-        output = fc_fp32_exe.forward()[0]
-
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, 
num_hidden=num_hidden,
-                                                           no_bias=no_bias, 
flatten=flatten)
-        qarg_names = fc_int8.list_arguments()
-        type_dict = {qarg_names[1]: 'int8'}
-        if not no_bias:
-            type_dict.update({qarg_names[2]: 'int8'})
-        fc_int8_exe = fc_int8._simple_bind(ctx=mx.current_context(), 
type_dict=type_dict, grad_req='null')
-        fc_int8_exe.arg_dict[qarg_names[0]][:] = 
fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
-        fc_int8_exe.arg_dict[qarg_names[1]][:] = 
fc_fp32_exe.arg_dict[arg_names[1]].astype('int8')
-        if no_bias:
-            fc_int8_exe.arg_dict[qarg_names[2]][:] = -data_range
-            fc_int8_exe.arg_dict[qarg_names[3]][:] = data_range
-            fc_int8_exe.arg_dict[qarg_names[4]][:] = -weight_range
-            fc_int8_exe.arg_dict[qarg_names[5]][:] = weight_range
-        else:
-            fc_int8_exe.arg_dict[qarg_names[2]][:] = bias.astype('int8')
-            fc_int8_exe.arg_dict[qarg_names[3]][:] = -data_range
-            fc_int8_exe.arg_dict[qarg_names[4]][:] = data_range
-            fc_int8_exe.arg_dict[qarg_names[5]][:] = -weight_range
-            fc_int8_exe.arg_dict[qarg_names[6]][:] = weight_range
-            fc_int8_exe.arg_dict[qarg_names[7]][:] = -bias_range
-            fc_int8_exe.arg_dict[qarg_names[8]][:] = bias_range
-        qoutput, min_range, max_range = fc_int8_exe.forward()
-
-        if no_bias:
-            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-        else:
+            new_args['bias'] = new_bias.astype('int32').astype('float32')
+
+        fc_fp32.load_dict(new_args, cast_dtype=True, dtype_source='saved')
+        output = fc_fp32(data)
+
+        class QuantFC(mx.gluon.nn.HybridBlock):
+            def __init__(self, num_hidden, use_bias, flatten, **kwargs):
+                super(QuantFC, self).__init__(**kwargs)
+                self.use_bias = use_bias
+                self._kwargs = {'num_hidden': num_hidden, 'no_bias': not 
use_bias, 'flatten': flatten}
+
+                self.min_data = mx.gluon.Parameter('min_data', 
dtype='float32', allow_deferred_init=True)
+                self.max_data = mx.gluon.Parameter('max_data', 
dtype='float32', allow_deferred_init=True)
+
+                self.weight = mx.gluon.Parameter('weight', dtype='int8', 
allow_deferred_init=True)
+                self.min_weight = mx.gluon.Parameter('min_weight', 
dtype='float32', allow_deferred_init=True)
+                self.max_weight = mx.gluon.Parameter('max_weight', 
dtype='float32', allow_deferred_init=True)
+
+                if use_bias:
+                    self.bias = mx.gluon.Parameter('bias', dtype='int8', 
allow_deferred_init=True)
+                    self.min_bias = mx.gluon.Parameter('min_bias', 
dtype='float32', allow_deferred_init=True)
+                    self.max_bias = mx.gluon.Parameter('max_bias', 
dtype='float32', allow_deferred_init=True)
+
+            def hybrid_forward(self, F, x, weight, bias=None, min_data=None, 
max_data=None,
+                               min_weight=None, max_weight=None, 
min_bias=None, max_bias=None):
+                out = F.contrib.quantized_fully_connected(data=x, 
weight=weight, bias=bias, 
+                                                          min_data=min_data, 
max_data=max_data,
+                                                          
min_weight=min_weight, max_weight=max_weight,
+                                                          min_bias=min_bias, 
max_bias=max_bias,
+                                                          **self._kwargs)
+                return out
+
+        fc_int8 = QuantFC(num_hidden=num_hidden, use_bias=use_bias, 
flatten=flatten)
+        qargs = {
+            'weight': new_args['weight'].astype('int8'),
+            'min_data': mx.nd.array(-data_range),
+            'max_data': mx.nd.array(data_range),
+            'min_weight': mx.nd.array(-weight_range),
+            'max_weight': mx.nd.array(weight_range)
+        }
+        if use_bias:
+            qargs.update({
+                'bias': bias.astype('int8'),
+                'min_bias': mx.nd.array(-bias_range),
+                'max_bias': mx.nd.array(bias_range),
+            })
+
+        fc_int8.load_dict(qargs, cast_dtype=True, dtype_source='saved')
+
+        qoutput, min_range, max_range = fc_int8(data.astype(qdtype))
+
+        if use_bias:
             # with adding bias, accuracy loss should not be greater than one
             diff = mx.nd.abs(output - qoutput.astype(output.dtype))
             cond = mx.nd.lesser(2, diff).sum().asscalar()
             assert cond == 0
+        else:
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
 
     for qdtype in ['int8', 'uint8']:
         if is_test_for_mkldnn():
-            check_quantized_fc((32, 512, 2), 100, True, qdtype, flatten=False)
             check_quantized_fc((32, 512, 2), 100, False, qdtype, flatten=False)
-            check_quantized_fc((32, 512, 2, 2), 100, True, qdtype, 
flatten=False)
+            check_quantized_fc((32, 512, 2), 100, True, qdtype, flatten=False)
             check_quantized_fc((32, 512, 2, 2), 100, False, qdtype, 
flatten=False)
-        check_quantized_fc((32, 512, 2, 2), 100, True, qdtype)
-        check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
+            check_quantized_fc((32, 512, 2, 2), 100, True, qdtype, 
flatten=False)
         check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
         check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
-        check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
-        check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
+        check_quantized_fc((32, 512, 2, 2), 100, True, qdtype)
+        check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
         check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
         check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
 
 
 def test_quantized_embedding():
@@ -617,34 +696,57 @@ def test_quantized_embedding():
         def maxabs(a, b):
             return mx.nd.maximum(mx.nd.abs(a), mx.nd.abs(b))
 
-        data0 = mx.sym.Variable(name='data', shape=data_shape, dtype='int32')
-        embedding_fp32 = mx.sym.Embedding(data=data0, input_dim=input_dim, 
output_dim=output_dim)
-        arg_shapes, _, _ = embedding_fp32.infer_shape(data=data_shape)
-        arg_names = embedding_fp32.list_arguments()
-        embedding_fp32_exe = 
embedding_fp32._simple_bind(ctx=mx.current_context(), grad_req='null')
+        data = mx.nd.random.uniform(low=0,
+                                    high=input_dim,
+                                    shape=data_shape
+                                   ).astype('int32').astype('float32')
+        embedding_fp32 = mx.gluon.nn.Embedding(input_dim=input_dim, 
output_dim=output_dim)
+        embedding_fp32.initialize()
+        embedding_fp32(data)
+        mx.nd.waitall()
+        fp32_params = embedding_fp32.collect_params()
         int8_range = 127.0
-        data = mx.nd.random.uniform(low=0, high=input_dim,
-                                      shape=arg_shapes[0]).astype('int32')
-        weight = mx.nd.random.uniform(low=-int8_range, high=int8_range,
-                                      shape=arg_shapes[1]).astype('int32')
-        embedding_fp32_exe.arg_dict[arg_names[0]][:] = data
-        embedding_fp32_exe.arg_dict[arg_names[1]][:] = weight
+        new_params = dict()
+        weight = mx.nd.random.uniform(low=-int8_range,
+                                      high=int8_range,
+                                      shape=fp32_params['weight'].shape
+                                     ).astype('int32').astype('float32')
+        new_params['weight'] = weight
+        embedding_fp32.load_dict(new_params, cast_dtype=True, 
dtype_source='saved')
+
+        output = embedding_fp32(data)
 
         weight_min = mx.nd.min(weight).astype('float32')
         weight_max = mx.nd.max(weight).astype('float32')
         weight_range = maxabs(weight_min, weight_max)
 
-        output = embedding_fp32_exe.forward()[0]
+        class QuantEmbedding(mx.gluon.nn.HybridBlock):
+            def __init__(self, input_dim=input_dim, output_dim=output_dim, 
**kwargs):
+                super(QuantEmbedding, self).__init__(**kwargs)
+                self._kwargs = {'input_dim': input_dim, 'output_dim': 
output_dim}
+
+                self.weight = mx.gluon.Parameter('weight', dtype='float32', 
allow_deferred_init=True)
+                self.min_weight = mx.gluon.Parameter('min_weight', 
dtype='float32', allow_deferred_init=True)
+                self.max_weight = mx.gluon.Parameter('max_weight', 
dtype='float32', allow_deferred_init=True)
+
+            def hybrid_forward(self, F, x, weight, min_weight=None, 
max_weight=None):
+                out = F.contrib.quantized_embedding(data=x, weight=weight,
+                                                    min_weight=min_weight,
+                                                    max_weight=max_weight,
+                                                    **self._kwargs)
+                return out
+
+        embedding_int8 = QuantEmbedding(input_dim=input_dim, 
output_dim=output_dim)
+        qargs = {
+            'weight': weight.astype('int8'),
+            'min_weight': mx.nd.array(-weight_range),
+            'max_weight': mx.nd.array(weight_range)
+        }
+
+        embedding_int8.load_dict(qargs, cast_dtype=True, dtype_source='saved')
+
+        qoutput, min_range, max_range = embedding_int8(data)
 
-        embedding_int8 = mx.sym.contrib.quantized_embedding(data=data0, 
input_dim=input_dim, output_dim=output_dim)
-        qarg_names = embedding_int8.list_arguments()
-        type_dict = {qarg_names[1]: 'int8'}
-        embedding_int8_exe = 
embedding_int8._simple_bind(ctx=mx.current_context(), type_dict=type_dict, 
grad_req='null')
-        embedding_int8_exe.arg_dict[qarg_names[0]][:] = 
embedding_fp32_exe.arg_dict[arg_names[0]]
-        embedding_int8_exe.arg_dict[qarg_names[1]][:] = 
embedding_fp32_exe.arg_dict[arg_names[1]].astype('int8')
-        embedding_int8_exe.arg_dict[qarg_names[2]][:] = -weight_range
-        embedding_int8_exe.arg_dict[qarg_names[3]][:] = weight_range
-        qoutput, min_range, max_range = embedding_int8_exe.forward()
 
         assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
 
@@ -691,11 +793,9 @@ def test_quantized_act():
         elif is_test_for_gpu():
             print('skipped testing quantized_act for gpu since it is not 
supported yet')
             return
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        act_fp32 = mx.sym.Activation(data=data, act_type='relu', name='relu')
-        arg_shapes, _, _ = act_fp32.infer_shape(data=data_shape)
-        arg_names = act_fp32.list_arguments()
-        act_fp32_exe = act_fp32._simple_bind(ctx=mx.current_context(), 
grad_req='null')
+
+        act_fp32 = mx.gluon.nn.Activation(activation='relu')
+
         if qdtype == 'uint8':
             data_low = 0.0
             data_high = 127.0
@@ -703,23 +803,27 @@ def test_quantized_act():
             data_low = -127.0
             data_high = 127.0
 
-        act_fp32_exe.arg_dict[arg_names[0]][:] = 
mx.nd.random.uniform(low=data_low,
-                                                high=data_high, 
shape=data_shape).astype(qdtype)
-        output = act_fp32_exe.forward()[0]
+        data = mx.nd.random.uniform(low=data_low,
+                                    high=data_high,
+                                    shape=data_shape
+                                   ).astype(qdtype).astype('float32')
+        output = act_fp32(data)
 
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        min_data = mx.sym.Variable(name='min_data')
-        max_data = mx.sym.Variable(name='max_data')
-        quantized_act = mx.sym.contrib.quantized_act(data=qdata, 
min_data=min_data, max_data=max_data, act_type='relu')
-        act_int8_exe = quantized_act._simple_bind(ctx=mx.current_context(), 
grad_req='null')
-        qarg_names = quantized_act.list_arguments()
+        class QuantActivation(mx.gluon.nn.HybridBlock):
+            def __init__(self, activation, **kwargs):
+                super(QuantActivation, self).__init__(**kwargs)
+                self._kwargs = {'act_type': activation}
 
-        act_int8_exe.arg_dict[qarg_names[0]][:] = 
act_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
-        quantized_range_min = 
mx.nd.min(act_int8_exe.arg_dict[qarg_names[0]][:])
-        quantized_range_max = 
mx.nd.max(act_int8_exe.arg_dict[qarg_names[0]][:])
-        act_int8_exe.arg_dict[qarg_names[1]][:] = 
quantized_range_min.astype(qdtype)
-        act_int8_exe.arg_dict[qarg_names[2]][:] = 
quantized_range_max.astype(qdtype)
-        qoutput, min_range, max_range = act_int8_exe.forward()
+            def hybrid_forward(self, F, x, min_data, max_data):
+                out = F.contrib.quantized_act(data=x, min_data=min_data, 
max_data=max_data, **self._kwargs)
+                return out
+
+        quantized_act = QuantActivation(activation='relu')
+
+        qdata = data.astype(qdtype)
+        quantized_range_min = mx.nd.min(data).astype('float32')
+        quantized_range_max = mx.nd.max(data).astype('float32')
+        qoutput, min_range, max_range = quantized_act(qdata, 
quantized_range_min, quantized_range_max)
 
         assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
         assert_almost_equal(min_range.asscalar(), 
quantized_range_min.asscalar())
@@ -760,48 +864,38 @@ def test_quantized_bn():
             data_high = 127.0
 
         # run fp32 bn
-        data_sym = mx.sym.Variable(name='data', shape=data_shape, 
dtype='float32')
-        bn_fp32 = mx.sym.BatchNorm(data=data_sym, name='bn', 
use_global_stats=True, fix_gamma=False)
-        arg_shapes, out_shapes, aux_shapes = 
bn_fp32.infer_shape(data=data_shape)
-        arg_names = bn_fp32.list_arguments()
-        aux_names = bn_fp32.list_auxiliary_states()
+        bn_fp32 = mx.gluon.nn.BatchNorm(use_global_stats=True, scale=True)
         data = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape)
-        gamma = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=arg_shapes[1])
-        beta = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=arg_shapes[2])
-        moving_mean, moving_var = get_mean_var(data)
-
-        bn_fp32_exe = bn_fp32._simple_bind(ctx=mx.current_context(), 
grad_req='null')
-        bn_fp32_exe.arg_dict[arg_names[0]][:] = data
-        bn_fp32_exe.arg_dict[arg_names[1]][:] = gamma
-        bn_fp32_exe.arg_dict[arg_names[2]][:] = beta
-        bn_fp32_exe.aux_dict[aux_names[0]][:] = moving_mean
-        bn_fp32_exe.aux_dict[aux_names[1]][:] = moving_var
-
-        output = bn_fp32_exe.forward()[0]
+        bn_fp32.initialize()
+        bn_fp32.hybridize()
+        bn_fp32(data)
+        fp32_params = bn_fp32.collect_params()
+        
+        data = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape)
+        gamma = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=fp32_params['gamma'].shape)
+        beta = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=fp32_params['beta'].shape)
+        running_mean, running_var = get_mean_var(data)
+        new_params = {
+            'gamma':gamma,
+            'beta':beta,
+            'running_mean': running_mean,
+            'running_var': running_var
+        }
+
+        bn_fp32.load_dict(new_params)
+        output = bn_fp32(data)
 
         # generate int8 bn from fp32 bn
-        arg_params = dict()
-        for k, v in bn_fp32_exe.arg_dict.items():
-            if 'data' in k or 'softmax_label' in k:
-                continue
-            arg_params[k] = v
-
         calib_data = mx.gluon.data.DataLoader(data, batch_size=data_shape[0])
-        qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=bn_fp32,
-                                                                         
arg_params=arg_params,
-                                                                         
aux_params=bn_fp32_exe.aux_dict,
-                                                                         
ctx=mx.current_context(),
-                                                                         
quantized_dtype=qdtype,
-                                                                         
quantize_mode='full',
-                                                                         
calib_mode='naive',
-                                                                         
calib_data=calib_data,
-                                                                         
num_calib_batches=1)
+        quant_bn = mx.contrib.quant.quantize_net(bn_fp32,
+                                                 quantized_dtype=qdtype,
+                                                 quantize_mode='full',
+                                                 calib_data=calib_data,
+                                                 calib_mode='naive',
+                                                 num_calib_batches=1,
+                                                 ctx=mx.current_context())
 
-        sym_block = mx.gluon.SymbolBlock(outputs=qsym, inputs=data_sym)
-        params = qarg_params
-        params.update(qaux_params)
-        sym_block.load_dict(params)
-        output_int8_to_fp32 = sym_block.forward(data)
+        output_int8_to_fp32 = quant_bn(data)
 
         assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), 
rtol=1e-1, atol=8)
 
@@ -837,61 +931,57 @@ def test_quantize_params():
             assert name.find('quantize') != -1
 
 
-def get_fp32_sym():
-    data = mx.sym.Variable('data')
-    conv = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv')
-    bn = mx.sym.BatchNorm(data=conv, eps=2e-05, fix_gamma=False, momentum=0.9, 
use_global_stats=False, name='bn')
-    act = mx.sym.Activation(data=bn, act_type='relu', name='relu')
-    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
-    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
-    sym = mx.sym.softmax(fc, name='softmax')
-    return sym
-
-def get_fp32_residual():
-    data = mx.sym.Variable('data')
-    conv0 = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), 
pad=(0,0),
-                               no_bias=True, name='conv0')
-    bn = mx.sym.BatchNorm(data=conv0, fix_gamma=False, eps=2e-5, momentum=0.9, 
name='bn')
-    sum0 = mx.sym.elemwise_add(bn, data, name='sum0')
-    act0 = mx.sym.Activation(data=sum0, act_type='relu', name='relu0')
-    pool0 = mx.sym.Pooling(act0, kernel=(4, 4), pool_type='avg', name='pool0')
-    conv1 = mx.sym.Convolution(data=pool0, num_filter=4, kernel=(1,1), 
pad=(0,0),
-                               no_bias=False, name='conv1')
-    act1 = mx.sym.Activation(data=conv1, act_type='relu', name='relu1')
-    pool1 = mx.sym.Pooling(act1, kernel=(4, 4), pool_type='avg', name='pool1')
-    fc = mx.sym.FullyConnected(pool1, num_hidden=10, flatten=True, name='fc')
-    sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, 
multi_output=False,
-                               out_grad=False, preserve_shape=False, 
use_ignore=False, name='softmax')
-    return sym
-
-def get_fp32_sym_with_multiple_outputs(length=1):
-    data = mx.sym.Variable('data')
-    inputs = list(mx.sym.split(data, axis=0, num_outputs=length, 
squeeze_axis=1, name='split'))
-
-    _conv_outs = []
-    for i in range(length):
-        _conv_outs.append(mx.sym.Convolution(data=inputs[i], kernel=(1, 1), 
num_filter=16, name='conv_{0}'.format(i)))
-    conv_out = [mx.sym.expand_dims(i, axis=0) for i in _conv_outs]
-    conv_out = mx.sym.Concat(*conv_out, dim=0, name='concat')
-    reshape_out = mx.sym.reshape(data=conv_out, shape=((length, -1)), 
name='reshape')
-    fc_out = mx.sym.FullyConnected(reshape_out, num_hidden=10, flatten=True, 
name='fc')
-    sym = mx.sym.softmax(fc_out, name='softmax')
-    return sym
-
-def get_fp32_sym_with_multiple_inputs():
-    data1 = mx.symbol.Variable('data1', shape=(64, 4, 10, 10), dtype='float32')
-    data2 = mx.symbol.Variable('data2', shape=(64, 4, 10, 10), dtype='float32')
-    weight1 = mx.symbol.Variable('conv1_weight', dtype='float32')
-    weight2 = mx.symbol.Variable('conv2_weight', dtype='float32')
-    conv1 = mx.symbol.Convolution(data=data1, weight=weight1, name='conv1', 
num_filter=64,
-                            kernel=(1, 1), stride=(1, 1), no_bias=True)
-    bn1 = mx.symbol.BatchNorm(data=conv1, name="bn1")
-    conv2 = mx.symbol.Convolution(data=data2, weight=weight2, name='conv2', 
num_filter=64,
-                            kernel=(1, 1), stride=(1, 1), no_bias=True)
-    bn2 = mx.symbol.BatchNorm(data=conv2, name="bn2")
-    sum = bn2 + bn1
-    return sum
-
+class FP32Net(mx.gluon.nn.HybridBlock):
+    def __init__(self, **kwargs):
+        super(FP32Net, self).__init__(**kwargs)
+        self.conv = mx.gluon.nn.Conv2D(channels=16, kernel_size=(1,1))
+        self.bn = mx.gluon.nn.BatchNorm(epsilon=2e-05, scale=True, 
momentum=0.9, use_global_stats=False)
+        self.act = mx.gluon.nn.Activation(activation='relu')
+        self.pool = mx.gluon.nn.AvgPool2D(pool_size=(4,4))
+        self.fc = mx.gluon.nn.Dense(units=10, flatten=True)
+
+    def hybrid_forward(self, F, x):
+        out = self.conv(x)
+        out = self.bn(out)
+        out = self.act(out)
+        out = self.pool(out)
+        out = self.fc(out)
+        return F.softmax(out)
+
+
+class FP32MultipleOutputs(mx.gluon.nn.HybridBlock):
+    def __init__(self, length, **kwargs):
+        super(FP32MultipleOutputs, self).__init__(**kwargs)
+        self.length = length
+        self.convs = mx.gluon.nn.Conv2D(channels=16, kernel_size=(1,1))
+        self.fc = mx.gluon.nn.Dense(units=10, flatten=True)
+
+    def hybrid_forward(self, F, x):
+        res = F.SliceChannel(x, num_outputs=self.length,
+                             axis=1, squeeze_axis=1)
+        out = []
+        for i in range(self.length):
+            out.append(self.convs(res[i]))
+            out[i] = F.expand_dims(out[i], axis=0)
+        out = F.concat(*out)
+        out = F.reshape(out, shape=((self.length, -1)))
+        out = self.fc(out)
+        return F.softmax(out)
+
+class FP32MultipleInputs(mx.gluon.nn.HybridBlock):
+    def __init__(self, **kwargs):
+        super(FP32MultipleInputs, self).__init__(**kwargs)
+        self.conv1 = mx.gluon.nn.Conv2D(channels=64, kernel_size=(1,1), 
use_bias=False)
+        self.bn1 = mx.gluon.nn.BatchNorm()
+        self.conv2 = mx.gluon.nn.Conv2D(channels=64, kernel_size=(1,1), 
use_bias=False)
+        self.bn2 = mx.gluon.nn.BatchNorm()
+
+    def hybrid_forward(self, F, data0, data1):
+        out0 = self.conv1(data0)
+        out0 = self.bn1(out0)
+        out1 = self.conv2(data1)
+        out1 = self.bn2(out1)
+        return out1 + out0
 
 @xfail_when_nonstandard_decimal_separator
 def test_quantize_model():
@@ -945,23 +1035,24 @@ def test_quantize_model():
             print('skipped testing quantize_model for gpu uint8 since it is 
not supported yet')
             return
 
-        sym = get_fp32_sym()
+        standard_net = FP32Net()
+        standard_net.initialize()
         batch_size = 4
         data_shape = (batch_size, 4, 10, 10)
 
         length = batch_size  # specify num of outputs from split op
-        msym = get_fp32_sym_with_multiple_outputs(length)
-        msym_data_shape = (length, 4, 4, 10, 10)
+        multi_out_net = FP32MultipleOutputs(length)
+        multi_out_net.initialize()
+        multi_out_data_shape = (length, 4, 4, 10, 10)
 
-        for s, dshape in zip((sym, msym), (data_shape, msym_data_shape)):
-            data = mx.sym.Variable('data')
-            sym_block = mx.gluon.SymbolBlock(outputs=s, inputs=data)
-            initialize_block_params(sym_block, mx.init.One())
+        for net, dshape in zip((standard_net, multi_out_net), (data_shape, 
multi_out_data_shape)):
             data = mx.nd.random.uniform(low=0, high=1, shape=dshape)
-            sym_block.forward(data)
-            arg_params, aux_params = collect_block_args_aux(sym_block, s)
+            net.hybridize()
+            net(data)
+            sym, _ = net.export(None)
+            arg_params, aux_params = collect_block_args_aux(net, sym)
 
-            qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=s,
+            qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=sym,
                                                                              
arg_params=arg_params,
                                                                              
aux_params=aux_params,
                                                                              
ctx=mx.current_context(),
@@ -973,7 +1064,7 @@ def test_quantize_model():
 
             calib_data = mx.nd.random.uniform(shape=dshape)
             calib_data = mx.gluon.data.DataLoader(calib_data, 
batch_size=batch_size)
-            qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=s,
+            qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=sym,
                                                                              
arg_params=arg_params,
                                                                              
aux_params=aux_params,
                                                                              
ctx=mx.current_context(),
@@ -991,24 +1082,23 @@ def test_quantize_model():
         if skip_not_supported():
             return
 
-        sym = get_fp32_sym_with_multiple_inputs()
+        net = FP32MultipleInputs()
+        net.initialize()
+        net.hybridize()
         dshape = (64, 4, 10, 10)
-        data1 = mx.sym.Variable('data1')
-        data2 = mx.sym.Variable('data2')
-        sym_block = mx.gluon.SymbolBlock(outputs=sym, inputs=[data1, data2])
-        initialize_block_params(sym_block, mx.init.One())
         data = [mx.nd.random.uniform(low=0, high=1, shape=dshape),
                 mx.nd.random.uniform(low=0, high=1, shape=dshape)]
-        sym_block.forward(*data)
-        arg_params, aux_params = collect_block_args_aux(sym_block, sym)
+        net(*data)
+        sym, _ = net.export(None)
+        arg_params, aux_params = collect_block_args_aux(net, sym)
 
         qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=sym,
-                                                                            
arg_params=arg_params,
-                                                                            
aux_params=aux_params,
-                                                                            
ctx=mx.current_context(),
-                                                                            
quantized_dtype=qdtype,
-                                                                            
calib_mode='none',
-                                                                            
quantize_mode='full')
+                                                                         
arg_params=arg_params,
+                                                                         
aux_params=aux_params,
+                                                                         
ctx=mx.current_context(),
+                                                                         
quantized_dtype=qdtype,
+                                                                         
calib_mode='none',
+                                                                         
quantize_mode='full')
         check_params(arg_params, qarg_params, qsym)
         check_params(aux_params, qaux_params)
 
@@ -1022,7 +1112,7 @@ def test_quantize_model():
                                                                          
quantized_dtype=qdtype,
                                                                          
calib_mode='naive',
                                                                          
calib_data=calib_data,
-                                                                         
data_names=["data1","data2"],
+                                                                         
data_names=["data0","data1"],
                                                                          
num_calib_batches=1,
                                                                          
quantize_mode='full')
         check_params(arg_params, qarg_params, qsym)
@@ -1091,6 +1181,16 @@ def test_quantize_sym_with_calib():
         print('skipped testing quantized_pooling for native cpu since it is 
not supported yet')
         return
 
+    def get_fp32_sym():
+        data = mx.sym.Variable('data')
+        conv = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, 
name='conv')
+        bn = mx.sym.BatchNorm(data=conv, eps=2e-05, fix_gamma=False, 
momentum=0.9, use_global_stats=False, name='bn')
+        act = mx.sym.Activation(data=bn, act_type='relu', name='relu')
+        pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
+        fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, 
name='fc')
+        sym = mx.sym.softmax(fc, name='softmax')
+        return sym
+
     sym = get_fp32_sym()
     offline_params = [name for name in sym.list_arguments()
                       if not name.startswith('data') and not 
name.endswith('label')]
@@ -1120,9 +1220,8 @@ def 
test_quantization_net_with_different_data_inputs_options():
         print('skipped testing 
test_quantization_net_with_different_data_inputs_options for gpu since it is 
not supported yet')
         return
 
-    sym = get_fp32_sym()
-    net = mx.gluon.SymbolBlock(sym, mx.sym.var('data'))
-    initialize_block_params(net, mx.init.Normal(0.2))
+    net = FP32Net()
+    net.initialize()
 
     batch_size = 32
     data_shape = (batch_size, 3, 224, 224)
@@ -1138,10 +1237,10 @@ def 
test_quantization_net_with_different_data_inputs_options():
 
 
     # pass data_shapes as list of DataDescs
-    net2 = mx.gluon.SymbolBlock(sym, mx.sym.var('data'))
-    initialize_block_params(net2, mx.init.Normal(0.2))
+    net2 = FP32Net()
+    net2.initialize()
     data_desc = mx.io.DataDesc('data', data_shape)
-    quantized_net2 = mx.contrib.quant.quantize_net(net,
+    quantized_net2 = mx.contrib.quant.quantize_net(net2,
                                                    quantized_dtype='auto',
                                                    data_shapes=[data_desc],
                                                    ctx=mx.current_context())
@@ -1150,10 +1249,10 @@ def 
test_quantization_net_with_different_data_inputs_options():
 
 
     # pass data as DataLoader
-    net3 = mx.gluon.SymbolBlock(sym, mx.sym.var('data'))
-    initialize_block_params(net3, mx.init.Normal(0.2))
+    net3 = FP32Net()
+    net3.initialize()
     data_loader = mx.gluon.data.DataLoader(random_data, batch_size=batch_size)
-    quantized_net3 = mx.contrib.quant.quantize_net(net,
+    quantized_net3 = mx.contrib.quant.quantize_net(net3,
                                                    quantized_dtype='auto',
                                                    calib_data=data_loader,
                                                    ctx=mx.current_context())

Reply via email to