bartekkuncer commented on a change in pull request #20227:
URL: https://github.com/apache/incubator-mxnet/pull/20227#discussion_r638713172



##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -229,69 +219,106 @@ def check_quantized_conv(data_shape, kernel, num_filter, 
pad, stride, dilate, no
             return
 
         # run fp32 conv
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        conv = mx.sym.Convolution(data=data, kernel=kernel, 
num_filter=num_filter, pad=pad, stride=stride,
-                                  dilate=dilate, no_bias=no_bias, 
cudnn_off=False, name='conv')
-        arg_shapes, _, _ = conv.infer_shape(data=data_shape)
-        arg_names = conv.list_arguments()
-        conv_exe_fp32 = conv._simple_bind(ctx=mx.current_context(), 
grad_req='null')
+        if len(data_shape) == 4:
+            convfp32 = mx.gluon.nn.Conv2D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        elif len(data_shape) == 5:
+            convfp32 = mx.gluon.nn.Conv3D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        else:
+            print('unsupported shape')
+            assert 1 == 0
+
         if qdtype == 'uint8':
             data_low = 0.0
             data_high = 127.0
         else:
             data_low = -127.0
             data_high = 127.0
-        conv_exe_fp32.arg_dict[arg_names[0]][:] = 
mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                       
shape=data_shape).astype('int32')
-        conv_exe_fp32.arg_dict[arg_names[1]][:] = 
mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                       
shape=arg_shapes[1]).astype('int32')
-        if not no_bias:
-            conv_exe_fp32.arg_dict[arg_names[2]][:] = 
mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           
shape=arg_shapes[2]).astype('int32')
-        output = conv_exe_fp32.forward()[0]
+
+        convfp32.initialize()
+        input_data = mx.nd.random.uniform(low=data_low,
+                                          high=data_high,
+                                          shape=data_shape
+                                         ).astype('int32').astype('float32')
+        convfp32(input_data) # initialize params
+        mx.nd.waitall()
+        fp32_params = convfp32.collect_params()
+        new_args = dict()
+        new_args['weight'] = mx.nd.random.uniform(low=-127.0,
+                                                  high=127.0,
+                                                  
shape=fp32_params['weight'].shape
+                                                 
).astype('int32').astype('float32')
+        if use_bias:
+           new_args['bias'] = mx.nd.random.uniform(low=-127.0,
+                                                   high=127.0,
+                                                   
shape=fp32_params['bias'].shape
+                                                  
).astype('int32').astype('float32')
+        convfp32.load_dict(new_args, cast_dtype=True, dtype_source='saved')
+
+        output = convfp32(input_data)
 
         # run quantized conv
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        qweight = mx.sym.Variable(name='qweight', dtype='int8')
-        min_data = mx.sym.Variable(name='min_data')
-        max_data = mx.sym.Variable(name='max_data')
-        min_weight = mx.sym.Variable(name='min_weight')
-        max_weight = mx.sym.Variable(name='max_weight')
-        quantized_conv = mx.sym.contrib.quantized_conv(data=qdata, 
weight=qweight, min_data=min_data,
-                                                       max_data=max_data, 
min_weight=min_weight,
-                                                       max_weight=max_weight, 
kernel=kernel,
-                                                       num_filter=num_filter, 
pad=pad, stride=stride,
-                                                       dilate=dilate, 
no_bias=no_bias)
-        qarg_names = quantized_conv.list_arguments()
-        type_dict = None
-        if not no_bias:
-            type_dict = {qarg_names[2]: 'int8'}
-        conv_exe_int8 = quantized_conv._simple_bind(ctx=mx.current_context(), 
type_dict=type_dict, grad_req='null')
-        conv_exe_int8.arg_dict[qarg_names[0]][:] = 
conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype)
-        conv_exe_int8.arg_dict[qarg_names[1]][:] = 
conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
+        class QuantConv(mx.gluon.nn.HybridBlock):
+            def __init__(self, channels, kernel_size, strides=(1, 1),
+                         padding=(0, 0), dilation=(1, 1), use_bias=True, 
**kwargs):
+                super(QuantConv, self).__init__(**kwargs)
+                self.use_bias = use_bias
+                self._kwargs = {
+                    'kernel': kernel_size, 'stride': strides, 'dilate': 
dilation,
+                    'pad': padding, 'num_filter': channels, 'no_bias': not 
use_bias, 'num_group': 1,
+                    'layout': 'NCHW'}
+
+                self.min_data = mx.gluon.Parameter('min_data', 
dtype='float32', allow_deferred_init=True)
+                self.max_data = mx.gluon.Parameter('max_data', 
dtype='float32', allow_deferred_init=True)
+
+                self.weight = mx.gluon.Parameter('weight', dtype='int8', 
allow_deferred_init=True)
+                self.min_weight = mx.gluon.Parameter('min_weight', 
dtype='float32', allow_deferred_init=True)
+                self.max_weight = mx.gluon.Parameter('max_weight', 
dtype='float32', allow_deferred_init=True)
+                
+                if use_bias:
+                    self.bias = mx.gluon.Parameter('bias', dtype='int8', 
allow_deferred_init=True)
+                    self.min_bias = mx.gluon.Parameter('min_bias', 
dtype='float32', allow_deferred_init=True)
+                    self.max_bias = mx.gluon.Parameter('max_bias', 
dtype='float32', allow_deferred_init=True)
+
+            def hybrid_forward(self, F, x, weight, bias=None, min_data=None, 
max_data=None,
+                               min_weight=None, max_weight=None, 
min_bias=None, max_bias=None):
+                out = F.contrib.quantized_conv(data=x, weight=weight, 
bias=bias, 
+                                               min_data=min_data, 
max_data=max_data,
+                                               min_weight=min_weight, 
max_weight=max_weight,
+                                               min_bias=min_bias, 
max_bias=max_bias,
+                                               **self._kwargs)
+                return out
+
+        convint8 = QuantConv(channels=num_filter, kernel_size=kernel, 
strides=stride,
+                             padding=pad, dilation=dilate, use_bias=use_bias)
+
         quantized_range = 127.0
-        if no_bias:
-            conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
-            conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
-        else:
-            conv_exe_int8.arg_dict[qarg_names[2]][:] = 
conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
-            conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
-            conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
-            conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
-            conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
-        qoutput, min_range, max_range = conv_exe_int8.forward()
-
-        if no_bias:
-            assert_almost_equal(output.asnumpy(), qoutput.asnumpy(), atol = 1)
-        else:
+        q_args = {

Review comment:
       Maybe rename this to 'qargs' to stay consistent with 'qdtype' and 'qoutput'? Or, since the underscore is more readable, rename 'qdtype' and 'qoutput' to 'q_dtype' and 'q_output' respectively?

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -229,69 +219,106 @@ def check_quantized_conv(data_shape, kernel, num_filter, 
pad, stride, dilate, no
             return
 
         # run fp32 conv
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        conv = mx.sym.Convolution(data=data, kernel=kernel, 
num_filter=num_filter, pad=pad, stride=stride,
-                                  dilate=dilate, no_bias=no_bias, 
cudnn_off=False, name='conv')
-        arg_shapes, _, _ = conv.infer_shape(data=data_shape)
-        arg_names = conv.list_arguments()
-        conv_exe_fp32 = conv._simple_bind(ctx=mx.current_context(), 
grad_req='null')
+        if len(data_shape) == 4:
+            convfp32 = mx.gluon.nn.Conv2D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        elif len(data_shape) == 5:
+            convfp32 = mx.gluon.nn.Conv3D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        else:
+            print('unsupported shape')
+            assert 1 == 0

Review comment:
       Why not just `assert False` instead of `assert 1 == 0`?

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -555,57 +612,79 @@ def maxabs(a, b):
             weight_scale = int8_range / weight_range
             bias_int32_rescale = data_scale * weight_scale / bias_scale
             new_bias = mx.nd.cast(bias, dtype='float32') * bias_int32_rescale
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = new_bias.astype('int32')
-
-        output = fc_fp32_exe.forward()[0]
-
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, 
num_hidden=num_hidden,
-                                                           no_bias=no_bias, 
flatten=flatten)
-        qarg_names = fc_int8.list_arguments()
-        type_dict = {qarg_names[1]: 'int8'}
-        if not no_bias:
-            type_dict.update({qarg_names[2]: 'int8'})
-        fc_int8_exe = fc_int8._simple_bind(ctx=mx.current_context(), 
type_dict=type_dict, grad_req='null')
-        fc_int8_exe.arg_dict[qarg_names[0]][:] = 
fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
-        fc_int8_exe.arg_dict[qarg_names[1]][:] = 
fc_fp32_exe.arg_dict[arg_names[1]].astype('int8')
-        if no_bias:
-            fc_int8_exe.arg_dict[qarg_names[2]][:] = -data_range
-            fc_int8_exe.arg_dict[qarg_names[3]][:] = data_range
-            fc_int8_exe.arg_dict[qarg_names[4]][:] = -weight_range
-            fc_int8_exe.arg_dict[qarg_names[5]][:] = weight_range
-        else:
-            fc_int8_exe.arg_dict[qarg_names[2]][:] = bias.astype('int8')
-            fc_int8_exe.arg_dict[qarg_names[3]][:] = -data_range
-            fc_int8_exe.arg_dict[qarg_names[4]][:] = data_range
-            fc_int8_exe.arg_dict[qarg_names[5]][:] = -weight_range
-            fc_int8_exe.arg_dict[qarg_names[6]][:] = weight_range
-            fc_int8_exe.arg_dict[qarg_names[7]][:] = -bias_range
-            fc_int8_exe.arg_dict[qarg_names[8]][:] = bias_range
-        qoutput, min_range, max_range = fc_int8_exe.forward()
-
-        if no_bias:
-            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-        else:
+            new_args['bias'] = new_bias.astype('int32').astype('float32')
+
+        fc_fp32.load_dict(new_args, cast_dtype=True, dtype_source='saved')
+        output = fc_fp32(data)
+
+        class QuantFC(mx.gluon.nn.HybridBlock):
+            def __init__(self, num_hidden, use_bias, flatten, **kwargs):
+                super(QuantFC, self).__init__(**kwargs)
+                self.use_bias = use_bias
+                self._kwargs = {'num_hidden': num_hidden, 'no_bias': not 
use_bias, 'flatten': flatten}

Review comment:
       Maybe try aligning these variable declarations?

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -555,57 +612,79 @@ def maxabs(a, b):
             weight_scale = int8_range / weight_range
             bias_int32_rescale = data_scale * weight_scale / bias_scale
             new_bias = mx.nd.cast(bias, dtype='float32') * bias_int32_rescale
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = new_bias.astype('int32')
-
-        output = fc_fp32_exe.forward()[0]
-
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, 
num_hidden=num_hidden,
-                                                           no_bias=no_bias, 
flatten=flatten)
-        qarg_names = fc_int8.list_arguments()
-        type_dict = {qarg_names[1]: 'int8'}
-        if not no_bias:
-            type_dict.update({qarg_names[2]: 'int8'})
-        fc_int8_exe = fc_int8._simple_bind(ctx=mx.current_context(), 
type_dict=type_dict, grad_req='null')
-        fc_int8_exe.arg_dict[qarg_names[0]][:] = 
fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
-        fc_int8_exe.arg_dict[qarg_names[1]][:] = 
fc_fp32_exe.arg_dict[arg_names[1]].astype('int8')
-        if no_bias:
-            fc_int8_exe.arg_dict[qarg_names[2]][:] = -data_range
-            fc_int8_exe.arg_dict[qarg_names[3]][:] = data_range
-            fc_int8_exe.arg_dict[qarg_names[4]][:] = -weight_range
-            fc_int8_exe.arg_dict[qarg_names[5]][:] = weight_range
-        else:
-            fc_int8_exe.arg_dict[qarg_names[2]][:] = bias.astype('int8')
-            fc_int8_exe.arg_dict[qarg_names[3]][:] = -data_range
-            fc_int8_exe.arg_dict[qarg_names[4]][:] = data_range
-            fc_int8_exe.arg_dict[qarg_names[5]][:] = -weight_range
-            fc_int8_exe.arg_dict[qarg_names[6]][:] = weight_range
-            fc_int8_exe.arg_dict[qarg_names[7]][:] = -bias_range
-            fc_int8_exe.arg_dict[qarg_names[8]][:] = bias_range
-        qoutput, min_range, max_range = fc_int8_exe.forward()
-
-        if no_bias:
-            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-        else:
+            new_args['bias'] = new_bias.astype('int32').astype('float32')
+
+        fc_fp32.load_dict(new_args, cast_dtype=True, dtype_source='saved')
+        output = fc_fp32(data)
+
+        class QuantFC(mx.gluon.nn.HybridBlock):
+            def __init__(self, num_hidden, use_bias, flatten, **kwargs):
+                super(QuantFC, self).__init__(**kwargs)
+                self.use_bias = use_bias
+                self._kwargs = {'num_hidden': num_hidden, 'no_bias': not 
use_bias, 'flatten': flatten}
+
+                self.min_data = mx.gluon.Parameter('min_data', 
dtype='float32', allow_deferred_init=True)
+                self.max_data = mx.gluon.Parameter('max_data', 
dtype='float32', allow_deferred_init=True)
+
+                self.weight = mx.gluon.Parameter('weight', dtype='int8', 
allow_deferred_init=True)
+                self.min_weight = mx.gluon.Parameter('min_weight', 
dtype='float32', allow_deferred_init=True)
+                self.max_weight = mx.gluon.Parameter('max_weight', 
dtype='float32', allow_deferred_init=True)
+
+                if use_bias:
+                    self.bias = mx.gluon.Parameter('bias', dtype='int8', 
allow_deferred_init=True)
+                    self.min_bias = mx.gluon.Parameter('min_bias', 
dtype='float32', allow_deferred_init=True)
+                    self.max_bias = mx.gluon.Parameter('max_bias', 
dtype='float32', allow_deferred_init=True)
+
+            def hybrid_forward(self, F, x, weight, bias=None, min_data=None, 
max_data=None,
+                               min_weight=None, max_weight=None, 
min_bias=None, max_bias=None):
+                out = F.contrib.quantized_fully_connected(data=x, 
weight=weight, bias=bias, 
+                                               min_data=min_data, 
max_data=max_data,

Review comment:
       These parameters are not aligned with the ones on the line above.

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -201,10 +195,10 @@ def check_requantize_with_symbol(shape, 
min_calib_range=None, max_calib_range=No
         assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
 
     # test with symbol API.

Review comment:
       This comment and the one below are inconsistent. I suggest unifying them.

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -229,69 +219,106 @@ def check_quantized_conv(data_shape, kernel, num_filter, 
pad, stride, dilate, no
             return
 
         # run fp32 conv
-        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        conv = mx.sym.Convolution(data=data, kernel=kernel, 
num_filter=num_filter, pad=pad, stride=stride,
-                                  dilate=dilate, no_bias=no_bias, 
cudnn_off=False, name='conv')
-        arg_shapes, _, _ = conv.infer_shape(data=data_shape)
-        arg_names = conv.list_arguments()
-        conv_exe_fp32 = conv._simple_bind(ctx=mx.current_context(), 
grad_req='null')
+        if len(data_shape) == 4:
+            convfp32 = mx.gluon.nn.Conv2D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        elif len(data_shape) == 5:
+            convfp32 = mx.gluon.nn.Conv3D(channels=num_filter, 
kernel_size=kernel, strides=stride,
+                                          padding=pad, dilation=dilate, 
use_bias=use_bias)
+        else:
+            print('unsupported shape')
+            assert 1 == 0
+
         if qdtype == 'uint8':
             data_low = 0.0
             data_high = 127.0
         else:
             data_low = -127.0
             data_high = 127.0
-        conv_exe_fp32.arg_dict[arg_names[0]][:] = 
mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                       
shape=data_shape).astype('int32')
-        conv_exe_fp32.arg_dict[arg_names[1]][:] = 
mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                       
shape=arg_shapes[1]).astype('int32')
-        if not no_bias:
-            conv_exe_fp32.arg_dict[arg_names[2]][:] = 
mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           
shape=arg_shapes[2]).astype('int32')
-        output = conv_exe_fp32.forward()[0]
+
+        convfp32.initialize()
+        input_data = mx.nd.random.uniform(low=data_low,
+                                          high=data_high,
+                                          shape=data_shape
+                                         ).astype('int32').astype('float32')
+        convfp32(input_data) # initialize params
+        mx.nd.waitall()
+        fp32_params = convfp32.collect_params()
+        new_args = dict()
+        new_args['weight'] = mx.nd.random.uniform(low=-127.0,
+                                                  high=127.0,
+                                                  
shape=fp32_params['weight'].shape
+                                                 
).astype('int32').astype('float32')
+        if use_bias:
+           new_args['bias'] = mx.nd.random.uniform(low=-127.0,
+                                                   high=127.0,
+                                                   
shape=fp32_params['bias'].shape
+                                                  
).astype('int32').astype('float32')
+        convfp32.load_dict(new_args, cast_dtype=True, dtype_source='saved')
+
+        output = convfp32(input_data)
 
         # run quantized conv
-        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
-        qweight = mx.sym.Variable(name='qweight', dtype='int8')
-        min_data = mx.sym.Variable(name='min_data')
-        max_data = mx.sym.Variable(name='max_data')
-        min_weight = mx.sym.Variable(name='min_weight')
-        max_weight = mx.sym.Variable(name='max_weight')
-        quantized_conv = mx.sym.contrib.quantized_conv(data=qdata, 
weight=qweight, min_data=min_data,
-                                                       max_data=max_data, 
min_weight=min_weight,
-                                                       max_weight=max_weight, 
kernel=kernel,
-                                                       num_filter=num_filter, 
pad=pad, stride=stride,
-                                                       dilate=dilate, 
no_bias=no_bias)
-        qarg_names = quantized_conv.list_arguments()
-        type_dict = None
-        if not no_bias:
-            type_dict = {qarg_names[2]: 'int8'}
-        conv_exe_int8 = quantized_conv._simple_bind(ctx=mx.current_context(), 
type_dict=type_dict, grad_req='null')
-        conv_exe_int8.arg_dict[qarg_names[0]][:] = 
conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype)
-        conv_exe_int8.arg_dict[qarg_names[1]][:] = 
conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
+        class QuantConv(mx.gluon.nn.HybridBlock):
+            def __init__(self, channels, kernel_size, strides=(1, 1),
+                         padding=(0, 0), dilation=(1, 1), use_bias=True, 
**kwargs):
+                super(QuantConv, self).__init__(**kwargs)
+                self.use_bias = use_bias
+                self._kwargs = {
+                    'kernel': kernel_size, 'stride': strides, 'dilate': 
dilation,
+                    'pad': padding, 'num_filter': channels, 'no_bias': not 
use_bias, 'num_group': 1,
+                    'layout': 'NCHW'}
+
+                self.min_data = mx.gluon.Parameter('min_data', 
dtype='float32', allow_deferred_init=True)

Review comment:
       Maybe try aligning these variable declarations?

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -760,48 +864,38 @@ def check_quantized_bn(data_shape, qdtype):
             data_high = 127.0
 
         # run fp32 bn
-        data_sym = mx.sym.Variable(name='data', shape=data_shape, 
dtype='float32')
-        bn_fp32 = mx.sym.BatchNorm(data=data_sym, name='bn', 
use_global_stats=True, fix_gamma=False)
-        arg_shapes, out_shapes, aux_shapes = 
bn_fp32.infer_shape(data=data_shape)
-        arg_names = bn_fp32.list_arguments()
-        aux_names = bn_fp32.list_auxiliary_states()
+        bn_fp32 = mx.gluon.nn.BatchNorm(use_global_stats=True, scale=True)
+        data = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape)
+        bn_fp32.initialize()
+        bn_fp32.hybridize()
+        bn_fp32(data)
+        fp32_params = bn_fp32.collect_params()
+        
         data = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape)
-        gamma = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=arg_shapes[1])
-        beta = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=arg_shapes[2])
+        gamma = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=fp32_params['gamma'].shape)
+        beta = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=fp32_params['beta'].shape)
         moving_mean, moving_var = get_mean_var(data)
+        new_params = {
+            'gamma':gamma,
+            'beta':beta,
+            'running_mean': moving_mean,

Review comment:
       Maybe rename the 'moving_*' variables to 'running_*' for consistency with the 'running_*' parameter names used below?

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -314,42 +341,46 @@ def check_quantized_elemwise_add(data_shape, qtype):
             print('skipped testing quantized_elemwise_add for gpu since it is 
not supported yet')
             return
 
-        dataA = mx.sym.Variable(name='dataA', shape=data_shape, 
dtype='float32')
-        dataB = mx.sym.Variable(name='dataB', shape=data_shape, 
dtype='float32')
-        elemwise_add_fp32 = mx.sym.elemwise_add(dataA, dataB)
-        arg_names = elemwise_add_fp32.list_arguments()
-        elemwise_add_fp32_exe = 
elemwise_add_fp32._simple_bind(ctx=mx.current_context(), grad_req='null')
+        class ElemwiseSumBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(ElemwiseSumBlock, self).__init__(**kwargs)
+
+            def hybrid_forward(self, F, dataA, dataB):
+                return F.elemwise_add(dataA, dataB)
+
+        class QuantElemwiseSumBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(QuantElemwiseSumBlock, self).__init__(**kwargs)
+
+            def hybrid_forward(self, F, dataA, dataB, dataA_min, dataA_max, 
dataB_min, dataB_max):
+                return F.contrib.quantized_elemwise_add(dataA, dataB, 
dataA_min, dataA_max, dataB_min, dataB_max)
+
+        elemwise_add_fp32 = ElemwiseSumBlock()
+
         if qtype == 'uint8':
             data_low = 0.0
             data_high = 255.0
         else:
             data_low = -127.0
             data_high = 127.0
 
-        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32')
-        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32')
-        elemwise_add_fp32_exe.arg_dict[arg_names[0]][:] = dataA_val
-
-        elemwise_add_fp32_exe.arg_dict[arg_names[1]][:] = dataB_val
-
-        output = elemwise_add_fp32_exe.forward()[0]
-        qdataA = mx.sym.Variable(name='qdataA', shape=data_shape, dtype=qtype)
-        qdataB = mx.sym.Variable(name='qdataB', shape=data_shape, dtype=qtype)
-        min_dataA = mx.sym.Variable(name='min_dataA', dtype='float32')
-        max_dataA = mx.sym.Variable(name='max_dataA', dtype='float32')
-        min_dataB = mx.sym.Variable(name='min_dataB', dtype='float32')
-        max_dataB = mx.sym.Variable(name='max_dataB', dtype='float32')
-        quantized_elemwise_add = mx.sym.contrib.quantized_elemwise_add(qdataA, 
qdataB, min_dataA, max_dataA, min_dataB, max_dataB)
-        elemwise_add_int8_exe = 
quantized_elemwise_add._simple_bind(ctx=mx.current_context(), grad_req='null')
-        qarg_names = quantized_elemwise_add.list_arguments()
-        elemwise_add_int8_exe.arg_dict[qarg_names[0]][:] = 
elemwise_add_fp32_exe.arg_dict[arg_names[0]].astype(qtype)
-        elemwise_add_int8_exe.arg_dict[qarg_names[1]][:] = 
elemwise_add_fp32_exe.arg_dict[arg_names[1]].astype(qtype)
+        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32').astype('float32')
+        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, 
shape=data_shape).astype('int32').astype('float32')
+
+        output = elemwise_add_fp32(dataA_val, dataB_val)
+
+        #run quantized
+        quantized_elemwise_add = QuantElemwiseSumBlock()
+        dataA_val_int8 = dataA_val.astype(qtype)
+        dataB_val_int8 = dataB_val.astype(qtype)
         quantized_range = 127.0
-        elemwise_add_int8_exe.arg_dict[qarg_names[2]][:] = data_low
-        elemwise_add_int8_exe.arg_dict[qarg_names[3]][:] = data_high
-        elemwise_add_int8_exe.arg_dict[qarg_names[4]][:] = data_low
-        elemwise_add_int8_exe.arg_dict[qarg_names[5]][:] = data_high
-        qoutput, min_range, max_range = elemwise_add_int8_exe.forward()
+        min_dataA = mx.nd.array([data_low])
+        max_dataA = mx.nd.array([data_high])
+        min_dataB = mx.nd.array([data_low])
+        max_dataB = mx.nd.array([data_high])
+        qoutput, min_range, max_range = quantized_elemwise_add(dataA_val_int8, 
dataB_val_int8,
+                                                               min_dataA, 
max_dataA,
+                                                               min_dataB, 
max_dataB)
         int8_rslt = qoutput.astype(output.dtype)*max_range/0x7fffffff

Review comment:
       Please add spaces around the operators here. Also, shouldn't the division operator be '//', since we are working with integers here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to