reminisce closed pull request #11937: Fix quantized graphpass bug
URL: https://github.com/apache/incubator-mxnet/pull/11937
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5376a0ee9f1..10834868d2b 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -221,6 +221,9 @@ Graph QuantizeGraph(Graph &&src) {
 
           new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
           mirror_map[e.node.get()] = std::move(dequantize_node);
+        } else if (mirror_node->op() != nullptr
+                   && mirror_node->op()->name == "_contrib_quantize") {
+          new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version});
         } else {
           new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version});
         }
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 359bbee569f..bfae58e49d0 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -396,6 +396,17 @@ def get_fp32_sym():
                                out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
     return sym
 
+def get_fp32_residual():
+    data = mx.sym.Variable('data')
+    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
+                              no_bias=True, name='conv')
+    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
+    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
+    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
+    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
+    sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False,
+                               out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
+    return sym 
 
 @with_seed()
 def test_quantize_model():
@@ -463,6 +474,101 @@ def check_qsym_qdtype(qsym, qdtype):
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
 
+@with_seed()
+def test_quantize_residual_unit():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_residual_unit for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_residual_unit for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_residual_unit for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params)
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
+            mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)],
+                     label_shapes=[('softmax_label', label_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+             
+
+        sym = get_fp32_residual()
+        mod = Module(symbol=sym)
+        batch_size = 4
+        data_shape = (batch_size, 4, 10, 10)
+        label_shape = (batch_size, 10)
+        mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+        if mx.current_context() == mx.cpu():
+           excluded_sym_names += ['fc']
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+        calib_data = mx.nd.random.uniform(shape=data_shape)
+        calib_data = NDArrayIter(data=calib_data)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantize_model(qdtype)
 
 @with_seed()
 def test_quantize_sym_with_calib():
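
For context, what the quantize_graph_pass.cc hunk above appears to do: when a
non-quantized consumer's input maps to a _contrib_quantize node (as happens for
the FP32 shortcut branch of a residual unit, where the same tensor also feeds a
quantized convolution), the pass now wires the consumer to the quantize node's
original FP32 input instead of the quantize node's output. The sketch below is
a condensed, hypothetical way to exercise that path from Python, following the
get_fp32_residual / test_quantize_residual_unit additions above; it assumes an
MXNet 1.x build with mx.contrib.quant available and is not part of the PR diff.

# Hypothetical sketch (not part of the PR): same residual topology as
# get_fp32_residual(), where 'data' feeds both the convolution (quantized)
# and the elementwise add (kept in FP32).
import mxnet as mx

data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1, 1), pad=(0, 0),
                          no_bias=True, name='conv')
bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
sym = mx.sym.SoftmaxOutput(fc, name='softmax')

# Initialize FP32 parameters once, then run the quantization pass without calibration.
mod = mx.mod.Module(symbol=sym, context=mx.cpu())
mod.bind(data_shapes=[('data', (4, 4, 10, 10))],
         label_shapes=[('softmax_label', (4, 10))])
mod.init_params()
arg_params, aux_params = mod.get_params()

qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    excluded_sym_names=['fc'], ctx=mx.cpu(),
    quantized_dtype='int8', calib_mode='none')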


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
