marcoabreu commented on a change in pull request #10433: [MXNET-290] MKLDNN support for model quantization
URL: https://github.com/apache/incubator-mxnet/pull/10433#discussion_r193374697
 
 

 ##########
 File path: tests/python/quantization/test_quantization.py
 ##########
 @@ -120,114 +128,126 @@ def check_requantize(shape, min_calib_range=None, max_calib_range=None):
 
 @with_seed()
 def test_quantized_conv():
-    if mx.current_context().device_type != 'gpu':
-        print('skipped testing quantized_conv on cpu since it is not implemented yet')
+    if is_test_for_naive_cpu():
+        print('skipped testing quantized_conv for naive cpu since it is not implemented yet')
         return
 
-    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias):
-        with mx.Context('gpu', 0):
-            # run fp32 conv
-            data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-            conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
-                                        no_bias=no_bias, cudnn_off=False, name='conv2d')
-            arg_shapes, _, _ = conv2d.infer_shape(data=data_shape)
-            arg_names = conv2d.list_arguments()
-            conv_exe_fp32 = conv2d.simple_bind(ctx=mx.current_context(), grad_req='null')
-            conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           shape=data_shape).astype('int32')
-            conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           shape=arg_shapes[1]).astype('int32')
-            if not no_bias:
-                conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                               shape=arg_shapes[2]).astype('int32')
-            output = conv_exe_fp32.forward()[0]
-
-            # run quantized conv
-            qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
-            qweight = mx.sym.Variable(name='qweight', dtype='int8')
-            min_data = mx.sym.Variable(name='min_data')
-            max_data = mx.sym.Variable(name='max_data')
-            min_weight = mx.sym.Variable(name='min_weight')
-            max_weight = mx.sym.Variable(name='max_weight')
-            quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
-                                                             max_data=max_data, min_weight=min_weight,
-                                                             max_weight=max_weight, kernel=kernel,
-                                                             num_filter=num_filter, pad=pad, stride=stride,
-                                                             no_bias=no_bias)
-            qarg_names = quantized_conv2d.list_arguments()
-            type_dict = None
-            if not no_bias:
-                type_dict = {qarg_names[2]: 'int8'}
-            conv_exe_int8 = quantized_conv2d.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
-            conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype('int8')
-            conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
-            quantized_range = 127.0
-            if no_bias:
-                conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
-                conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
-            else:
-                conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
-                conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
-                conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
-                conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
-            qoutput, min_range, max_range = conv_exe_int8.forward()
+    if is_test_for_mkldnn():
+        dtype = 'uint8'
+        shift = 127
+    else:
+        dtype = 'int8'
+        shift = 0
 
-            if no_bias:
-                assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-            else:
-                # with adding bias, accuracy loss should not be greater than one
-                diff = mx.nd.abs(output - qoutput.astype(output.dtype))
-                cond = mx.nd.lesser(2, diff).sum().asscalar()
-                assert cond == 0
+    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias):
+        # run fp32 conv
+        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+        conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
+                                    no_bias=no_bias, cudnn_off=False, name='conv2d')
+        arg_shapes, _, _ = conv2d.infer_shape(data=data_shape)
+        arg_names = conv2d.list_arguments()
+        conv_exe_fp32 = conv2d.simple_bind(ctx=mx.current_context(), grad_req='null')
+        conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0 + shift, high=127.0,
+                                                                       shape=data_shape).astype('int32')
+        conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                       shape=arg_shapes[1]).astype('int32')
+        if not no_bias:
+            conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                           shape=arg_shapes[2]).astype('int32')
+        output = conv_exe_fp32.forward()[0]
+
+        # run quantized conv
+        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=dtype)
+        qweight = mx.sym.Variable(name='qweight', dtype='int8')
+        min_data = mx.sym.Variable(name='min_data')
+        max_data = mx.sym.Variable(name='max_data')
+        min_weight = mx.sym.Variable(name='min_weight')
+        max_weight = mx.sym.Variable(name='max_weight')
+        quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
+                                                         max_data=max_data, min_weight=min_weight,
+                                                         max_weight=max_weight, kernel=kernel,
+                                                         num_filter=num_filter, pad=pad, stride=stride,
+                                                         no_bias=no_bias)
+        qarg_names = quantized_conv2d.list_arguments()
+        type_dict = None
+        if not no_bias:
+            type_dict = {qarg_names[2]: 'int8'}
+        conv_exe_int8 = quantized_conv2d.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
+        conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype(dtype)
+        conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
+        quantized_range = 127.0
+        if no_bias:
+            conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
+        else:
+            conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
+            conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
+        qoutput, min_range, max_range = conv_exe_int8.forward()
+
+        if no_bias:
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
+        else:
+            # with adding bias, accuracy loss should not be greater than one
+            diff = mx.nd.abs(output - qoutput.astype(output.dtype))
+            cond = mx.nd.lesser(2, diff).sum().asscalar()
+            assert cond == 0
 
     check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), True)
     check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), False)
 
 
 @with_seed()
 def test_quantized_pooling():
-    if mx.current_context().device_type != 'gpu':
-        print('skipped testing quantized_pooling on cpu since it is not implemented yet')
+    if is_test_for_naive_cpu():
+        print('skipped testing quantized_pooling for naive cpu since it is not implemented yet')
         return
 
-    def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool):
-        with mx.Context('gpu', 0):
-            data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-            pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
-                                          pool_type=pool_type, global_pool=global_pool, cudnn_off=False)
-            arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape)
-            arg_names = pooling_fp32.list_arguments()
-            pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
-            pooling_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                              shape=data_shape).astype('int32')
-            output = pooling_fp32_exe.forward()[0]
-
-            qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
-            min_data = mx.sym.Variable(name='min_data')
-            max_data = mx.sym.Variable(name='max_data')
-            quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data,
-                                                                 max_data=max_data, kernel=kernel,
-                                                                 pad=pad, stride=stride, pool_type=pool_type,
-                                                                 global_pool=global_pool)
-            pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null')
-            qarg_names = quantized_pooling.list_arguments()
-            pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype('int8')
-            quantized_range = 127.0
-            pooling_int8_exe.arg_dict[qarg_names[1]][:] = -quantized_range
-            pooling_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range
-            qoutput, min_range, max_range = pooling_int8_exe.forward()
+    if is_test_for_mkldnn():
+        dtype = 'uint8'
 
 Review comment:
  Why different dtypes? The test is written to target int8, so why test uint8 now?
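
  For context, a minimal, illustrative sketch (not part of the PR; values taken from the diff above, the snippet itself is an assumption about its intent) of what the dtype/shift change does to the generated test data:

      import mxnet as mx

      # Mirrors the data generation in the diff: under MKLDNN the test uses
      # shift = 127 and dtype = 'uint8', so the uniform range becomes [0, 127]
      # and the fp32 input can be cast to uint8 without wrapping; otherwise
      # shift = 0 keeps the range [-127, 127] and int8 is used.
      shift, dtype = 127, 'uint8'
      data_shape = (3, 4, 28, 28)
      data = mx.nd.random.uniform(low=-127.0 + shift, high=127.0,
                                  shape=data_shape).astype('int32')
      assert data.asnumpy().min() >= 0   # every value fits into uint8
      qdata = data.astype(dtype)         # lossless cast for this range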

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
