reminisce closed pull request #11937: Fix quantized graph pass bug
URL: https://github.com/apache/incubator-mxnet/pull/11937
This PR was merged from a forked repository. Because GitHub hides the original diff once such a pull request is merged, it is reproduced below for provenance:

diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5376a0ee9f1..10834868d2b 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -221,6 +221,9 @@ Graph QuantizeGraph(Graph &&src) {
         new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
         mirror_map[e.node.get()] = std::move(dequantize_node);
+      } else if (mirror_node->op() != nullptr
+                 && mirror_node->op()->name == "_contrib_quantize") {
+        new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version});
       } else {
         new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version});
       }
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 359bbee569f..bfae58e49d0 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -396,6 +396,17 @@ def get_fp32_sym():
                               out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
     return sym
 
+def get_fp32_residual():
+    data = mx.sym.Variable('data')
+    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
+                              no_bias=True, name='conv')
+    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
+    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
+    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
+    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
+    sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False,
+                               out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
+    return sym
 
 @with_seed()
 def test_quantize_model():
@@ -463,6 +474,101 @@ def check_qsym_qdtype(qsym, qdtype):
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
 
+@with_seed()
+def test_quantize_residual_unit():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_residual_unit for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_residual_unit for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_residual_unit for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params)
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
+            mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)],
+                     label_shapes=[('softmax_label', label_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+
+        sym = get_fp32_residual()
+        mod = Module(symbol=sym)
+        batch_size = 4
+        data_shape = (batch_size, 4, 10, 10)
+        label_shape = (batch_size, 10)
+        mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+        if mx.current_context() == mx.cpu():
+            excluded_sym_names += ['fc']
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+        calib_data = mx.nd.random.uniform(shape=data_shape)
+        calib_data = NDArrayIter(data=calib_data)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantize_model(qdtype)
+
 @with_seed()
 def test_quantize_sym_with_calib():
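For anyone who wants to exercise the new code path outside the test suite, here is a minimal sketch, assuming an MXNet build that includes this change and exposes mx.contrib.quant.quantize_model. The symbol mirrors get_fp32_residual() from the test above: the 'data' variable feeds both the convolution (which gets quantized) and the FP32 residual add, which is the pattern the graph-pass fix targets. Everything beyond the test symbol itself (shapes, excluded layers, the final print) is illustrative, not taken from the PR.

import mxnet as mx

# Build the same small residual block as get_fp32_residual() in the test.
data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1, 1), pad=(0, 0),
                          no_bias=True, name='conv')
bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')  # 'data' is consumed twice
pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
sym = mx.sym.SoftmaxOutput(fc, name='softmax')

# Initialize FP32 parameters so quantize_model has something to convert.
mod = mx.mod.Module(symbol=sym)
mod.bind(data_shapes=[('data', (4, 4, 10, 10))], label_shapes=[('softmax_label', (4, 10))])
mod.init_params()
arg_params, aux_params = mod.get_params()

# Quantize without calibration; 'fc' is excluded on CPU, as in the test above.
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    excluded_sym_names=['fc'], ctx=mx.cpu(),
    quantized_dtype='int8', calib_mode='none')

# Inspect the rewritten graph: per the quantize_graph_pass.cc change above, the FP32
# residual add is now fed by the original 'data' input rather than by the output of
# the inserted _contrib_quantize node.
print([name for name in qsym.get_internals().list_outputs() if 'quantize' in name])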