This is an automated email from the ASF dual-hosted git repository.
apeforest pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8604c3c [Mxnet-1397] Support symbolic api for requantize and dequantize (#14749)
8604c3c is described below
commit 8604c3c4c10468401e9535e5a4d4f4638fd013ef
Author: shoubhik <[email protected]>
AuthorDate: Wed Apr 24 14:06:29 2019 -0700
[Mxnet-1397] Support symbolic api for requantize and dequantize (#14749)
* Adding support for symbolic API for requantize and dequantize
* Adding name to contributors list
* Removing redundant code
* Addressing indentation and using current_context() instead of cpu()
* merge from master
* merge from master
---
CONTRIBUTORS.md | 1 +
src/operator/quantization/dequantize.cc | 4 ++
src/operator/quantization/requantize.cc | 4 ++
tests/python/quantization/test_quantization.py | 82 ++++++++++++++++++++++----
4 files changed, 80 insertions(+), 11 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 9508f1e..be497e5 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -236,6 +236,7 @@ List of Contributors
* [Zhennan Qin](https://github.com/ZhennanQin)
* [Zhiyuan Huang](https://github.com/huangzhiyuan)
* [Zak Jost](https://github.com/zjost)
+* [Shoubhik Bhattacharya](https://github.com/shoubhik)
* [Zach Kimberg](https://github.com/zachgk)
Label Bot
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index dd433e4..e8e2cd9 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -84,6 +84,10 @@ by keep zero centered for the quantized value:
.set_attr_parser(ParamParser<DequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data", "min_range", "max_range"};
+ })
.set_attr<mxnet::FInferShape>("FInferShape", DequantizeShape)
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc
index 4807226..4368238 100644
--- a/src/operator/quantization/requantize.cc
+++ b/src/operator/quantization/requantize.cc
@@ -61,6 +61,10 @@ inference accuracy.
.set_attr_parser(ParamParser<RequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(3)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data", "min_range", "max_range"};
+ })
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
.set_attr<nnvm::FInferType>("FInferType", RequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 2761e77..3c8cc42 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -63,19 +63,45 @@ def test_quantize_float32_to_int8():
@with_seed()
def test_dequantize_int8_to_float32():
+
+ def get_test_data(real_range, qdata_np):
+ qdata = mx.nd.array(qdata_np, dtype=np.int8)
+ min_range = mx.nd.array([-real_range], dtype=np.float32)
+ max_range = mx.nd.array([real_range], dtype=np.float32)
+ return qdata, min_range, max_range
+
+ def baseline_dequantization(qdata, real_range, qdata_np):
+ quantized_range = 127.0
+ scale = real_range / quantized_range
+ data_np = qdata_np * scale
+ return data_np
+
+ def test_nd_array_dequantization(qdata, min_range, max_range,
expected_result):
+ data = mx.nd.contrib.dequantize(qdata, min_range, max_range,
out_type='float32')
+ assert data.dtype == np.float32
+ assert_almost_equal(data.asnumpy(), expected_result)
+
+ def test_symbolic_api_dequantization(qdata, min_range, max_range,
expected_result):
+ sym_data = mx.sym.Variable('data')
+ sym_min_range = mx.sym.Variable('min_range')
+ sym_max_range = mx.sym.Variable('max_range')
+ dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
+ sym_max_range, out_type='float32')
+ out = dequant.bind(ctx=mx.current_context(),
+ args={'data':qdata, 'min_range':min_range,
'max_range':max_range})
+ data = out.forward()[0]
+ assert data.dtype == np.float32
+ assert_almost_equal(data.asnumpy(), expected_result)
+
+ real_range = 402.3347
shape = rand_shape_nd(4)
qdata_np = np.random.uniform(low=-127, high=127,
size=shape).astype(dtype=np.int8)
- qdata = mx.nd.array(qdata_np, dtype=np.int8)
- real_range = 402.3347
- min_range = mx.nd.array([-real_range], dtype=np.float32)
- max_range = mx.nd.array([real_range], dtype=np.float32)
- data = mx.nd.contrib.dequantize(qdata, min_range, max_range,
out_type='float32')
- quantized_range = 127.0
- scale = real_range / quantized_range
- assert data.dtype == np.float32
- data_np = qdata_np * scale
- assert_almost_equal(data.asnumpy(), data_np)
-
+ qdata, min_range, max_range = get_test_data(real_range, qdata_np)
+ expected_result = baseline_dequantization(qdata, real_range, qdata_np)
+ # test nd array implementation.
+ test_nd_array_dequantization(qdata, min_range, max_range, expected_result)
+ # test symbolic api implementaion.
+ test_symbolic_api_dequantization(qdata, min_range, max_range,
expected_result)
@with_seed()
def test_requantize_int32_to_int8():
@@ -124,7 +150,41 @@ def test_requantize_int32_to_int8():
assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1)
assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
+
+ def check_requantize_with_symbol(shape, min_calib_range=None,
max_calib_range=None):
+ qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0,
shape=shape).astype('int32')
+ min_range = mx.nd.array([-1010.0])
+ max_range = mx.nd.array([1020.0])
+ sym_data = mx.sym.Variable('data')
+ sym_min_range = mx.sym.Variable('min_range')
+ sym_max_range = mx.sym.Variable('max_range')
+ if min_calib_range is None or max_calib_range is None:
+ requant = mx.sym.contrib.requantize(sym_data, sym_min_range,
sym_max_range)
+ out = requant.bind(ctx=mx.current_context(),
+ args={'data':qdata, 'min_range':min_range,
+ 'max_range':max_range})
+ qdata_int8, min_output, max_output = out.forward()
+ else:
+ requant = mx.sym.contrib.requantize(sym_data, sym_min_range,
sym_max_range,
+ min_calib_range,
max_calib_range)
+ out = requant.bind(ctx=mx.current_context(), args={'data':qdata,
'min_range':min_range,
+ 'max_range':max_range})
+ qdata_int8, min_output, max_output = out.forward()
+
+ qdata_int8_np, min_output_np, max_output_np =
requantize_baseline(qdata.asnumpy(), min_range.asscalar(),
+
max_range.asscalar(),
+
min_calib_range=min_calib_range,
+
max_calib_range=max_calib_range)
+ assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np)
+ assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
+ assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
+ # test with symbol API.
+ check_requantize_with_symbol((3, 4, 10, 10))
+ check_requantize_with_symbol((32, 3, 23, 23))
+ check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0,
max_calib_range=1040.0)
+ check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349,
max_calib_range=523.43)
+ # Test with nd array API
check_requantize((3, 4, 10, 10))
check_requantize((32, 3, 23, 23))
check_requantize((3, 4, 10, 10), min_calib_range=-1050.0,
max_calib_range=1040.0)