bgawrych commented on a change in pull request #19587:
URL: https://github.com/apache/incubator-mxnet/pull/19587#discussion_r536243581
##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -0,0 +1,1210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Some of the tests using CUDNN require a special GPU instruction called dp4a.
+Ref: http://images.nvidia.com/content/pdf/tesla/184457-Tesla-P4-Datasheet-NV-Final-Letter-Web.pdf
+"""
+import os
+import mxnet as mx
+import numpy as np
+from mxnet.gluon.model_zoo import vision
+from mxnet.test_utils import assert_almost_equal, assert_exception, rand_ndarray, rand_shape_nd, same, DummyIter
+from common import xfail_when_nonstandard_decimal_separator
+from mxnet.io import NDArrayIter
+import unittest
+import operator
+
+
+def initialize_block_params(block, initializer):
+    for name, param in block.collect_params('.*gamma|.*running_var|.*moving_var').items():
+        param.initialize(mx.init.Constant(1))
+    for name, param in block.collect_params('.*beta|.*bias|.*moving_mean|.*running_mean').items():
+        param.initialize(mx.init.Constant(0))
+    for name, param in block.collect_params('.*weight').items():
+        param.initialize(initializer)
+
+
+def collect_block_args_aux(block, sym):
+    arg_params, aux_params = dict(), dict()
+    for k, v in block.collect_params().items():
+        if k in sym.list_arguments():
+            arg_params[k] = v._reduce()
+        elif k in sym.list_auxiliary_states():
+            aux_params[k] = v._reduce()
+    return arg_params, aux_params
+
+
+def is_test_for_gpu():
+    return mx.current_context().device_type == 'gpu'
+
+
+def is_test_for_mkldnn():
+    return (mx.current_context().device_type == 'cpu'
+            and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') == '1')
+
+
+def is_test_for_native_cpu():
+    return (mx.current_context().device_type == 'cpu'
+            and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') is None)
+
+
+def test_quantize_float32_to_int8():
+    shape = rand_shape_nd(4)
+    data = rand_ndarray(shape, 'default', dtype='float32')
+    min_range = mx.nd.min(data)
+    max_range = mx.nd.max(data)
+    qdata, min_val, max_val = mx.nd.contrib.quantize(data, min_range, max_range, out_type='int8')
+    data_np = data.asnumpy()
+    min_range = min_range.asscalar()
+    max_range = max_range.asscalar()
+    real_range = np.maximum(np.abs(min_range), np.abs(max_range))
+    quantized_range = 127.0
+    scale = quantized_range / real_range
+    assert qdata.dtype == np.int8
+    assert min_val.dtype == np.float32
+    assert max_val.dtype == np.float32
+    assert same(min_val.asscalar(), -real_range)
+    assert same(max_val.asscalar(), real_range)
+    qdata_np = (np.sign(data_np) * np.minimum(np.abs(data_np) * scale + 0.5, quantized_range)).astype(np.int8)
+    assert_almost_equal(qdata.asnumpy(), qdata_np, atol=1)
+
+
+def test_dequantize_int8_to_float32():
+
+    def get_test_data(real_range, qdata_np):
+        qdata = mx.nd.array(qdata_np, dtype=np.int8)
+        min_range = mx.nd.array([-real_range], dtype=np.float32)
+        max_range = mx.nd.array([real_range], dtype=np.float32)
+        return qdata, min_range, max_range
+
+    def baseline_dequantization(qdata, real_range, qdata_np):
+        quantized_range = 127.0
+        scale = real_range / quantized_range
+        data_np = qdata_np * scale
+        return data_np
+
+    def test_nd_array_dequantization(qdata, min_range, max_range, expected_result):
+        data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32')
+        assert data.dtype == np.float32
+        assert_almost_equal(data.asnumpy(), expected_result, atol=1)
+
+    def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result):
+        sym_data = mx.sym.Variable('data')
+        sym_min_range = mx.sym.Variable('min_range')
+        sym_max_range = mx.sym.Variable('max_range')
+        dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
+                                            sym_max_range, out_type='float32')
+        out = dequant._bind(ctx=mx.current_context(),
+                            args={'data': qdata, 'min_range': min_range, 'max_range': max_range})
+        data = out.forward()[0]
+        assert data.dtype == np.float32
+        assert_almost_equal(data.asnumpy(), expected_result, atol=1)
+
+    real_range = 128
+    shape = rand_shape_nd(4)
+    qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8)
+    qdata, min_range, max_range = get_test_data(real_range, qdata_np)
+    expected_result = baseline_dequantization(qdata, real_range, qdata_np)
+    # test nd array implementation.
+    test_nd_array_dequantization(qdata, min_range, max_range, expected_result)
+    # test symbolic api implementation.
+    test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result)
+
+
+def test_requantize_int32_to_int8():
+    def quantized_int32_to_float(qdata, min_range, max_range):
+        assert qdata.dtype == 'int32'
+        quantized_range = np.iinfo('int32').max
+        real_range = np.maximum(np.abs(min_range), np.abs(max_range))
+        scale = float(real_range) / float(quantized_range)
+        return qdata.astype('float32') * scale
+
+    def float_to_quantized_int8(data, min_range, max_range):
+        assert data.dtype == 'float32'
+        real_range = np.maximum(np.abs(min_range), np.abs(max_range))
+        quantized_range = np.iinfo('int8').max
+        scale = float(quantized_range) / float(real_range)
+        return (np.sign(data) * np.minimum(np.abs(data) * scale + 0.5, quantized_range)).astype('int8')
+
+    def requantize(qdata, min_data, max_data, real_range):
+        data = quantized_int32_to_float(qdata, min_data, max_data)
+        output = float_to_quantized_int8(data, -real_range, real_range)
+        return output, -real_range, real_range
+
+    def requantize_baseline(qdata, min_data, max_data, min_calib_range=None, max_calib_range=None):
+        if min_calib_range is not None and max_calib_range is not None:
+            real_range = np.maximum(np.abs(min_calib_range), np.abs(max_calib_range))
+            return requantize(qdata, min_data, max_data, real_range)
+        else:
+            min_range = quantized_int32_to_float(np.min(qdata), min_data, max_data)
+            max_range = quantized_int32_to_float(np.max(qdata), min_data, max_data)
+            return requantize(qdata, min_data, max_data, np.maximum(np.abs(min_range), np.abs(max_range)))
+
+    def check_requantize(shape, min_calib_range=None, max_calib_range=None):
+        qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32')
+        min_range = mx.nd.array([-1010.0])
+        max_range = mx.nd.array([1020.0])
+        if min_calib_range is None or max_calib_range is None:
+            qdata_int8, min_output, max_output = mx.nd.contrib.requantize(qdata, min_range, max_range)
+        else:
+            qdata_int8, min_output, max_output = mx.nd.contrib.requantize(qdata, min_range, max_range,
+                                                                          min_calib_range=min_calib_range,
+                                                                          max_calib_range=max_calib_range)
+
+        qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(),
+                                                                          max_range.asscalar(),
+                                                                          min_calib_range=min_calib_range,
+                                                                          max_calib_range=max_calib_range)
+        assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol=1)
+        assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
+        assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
+
+    def check_requantize_with_symbol(shape, min_calib_range=None, max_calib_range=None):
+        qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32')
+        min_range = mx.nd.array([-1010.0])
+        max_range = mx.nd.array([1020.0])
+        sym_data = mx.sym.Variable('data')
+        sym_min_range = mx.sym.Variable('min_range')
+        sym_max_range = mx.sym.Variable('max_range')
+        if min_calib_range is None or max_calib_range is None:
+            requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range)
+            out = requant._bind(ctx=mx.current_context(),
+                                args={'data': qdata, 'min_range': min_range,
+                                      'max_range': max_range})
+            qdata_int8, min_output, max_output = out.forward()
+        else:
+            requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range,
+                                                min_calib_range=min_calib_range,
+                                                max_calib_range=max_calib_range)
+            out = requant._bind(ctx=mx.current_context(), args={'data': qdata, 'min_range': min_range,
+                                                                'max_range': max_range})
+            qdata_int8, min_output, max_output = out.forward()
+
+        qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(),
+                                                                          max_range.asscalar(),
+                                                                          min_calib_range=min_calib_range,
+                                                                          max_calib_range=max_calib_range)
+        assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol=1)
+        assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
+        assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))
+
+    # test with symbol API.
+    check_requantize_with_symbol((3, 4, 10, 10))
+    check_requantize_with_symbol((32, 3, 23, 23))
+    check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)
+    check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43)
+    # test with nd array API.
+    check_requantize((3, 4, 10, 10))
+    check_requantize((32, 3, 23, 23))
+    check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)
+    check_requantize((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43)
+
+
+def test_quantized_conv():
+    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no_bias, qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_conv for native cpu since it is not supported yet')
+            return
+        elif is_test_for_mkldnn():
+            # TODO(Xinyu): https://github.com/apache/incubator-mxnet/issues/16830
+            print('skipped testing quantized_conv for mkldnn cpu since it is a flaky case')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_conv for gpu uint8 since it is not supported yet')
+            return
+        elif is_test_for_gpu() and len(data_shape) != 4:
+            print('skipped testing quantized_conv for gpu 5d layout since it is not supported yet')
+            return
+
+        # run fp32 conv
+        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+        conv = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
+                                  dilate=dilate, no_bias=no_bias, cudnn_off=False, name='conv')
+        arg_shapes, _, _ = conv.infer_shape(data=data_shape)
+        arg_names = conv.list_arguments()
+        conv_exe_fp32 = conv._simple_bind(ctx=mx.current_context(), grad_req='null')
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                       shape=data_shape).astype('int32')
+        conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                       shape=arg_shapes[1]).astype('int32')
+        if not no_bias:
+            conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                           shape=arg_shapes[2]).astype('int32')
+        output = conv_exe_fp32.forward()[0]
+
+        # run quantized conv
+        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
+        qweight = mx.sym.Variable(name='qweight', dtype='int8')
+        min_data = mx.sym.Variable(name='min_data')
+        max_data = mx.sym.Variable(name='max_data')
+        min_weight = mx.sym.Variable(name='min_weight')
+        max_weight = mx.sym.Variable(name='max_weight')
+        quantized_conv = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
+                                                       max_data=max_data, min_weight=min_weight,
+                                                       max_weight=max_weight, kernel=kernel,
+                                                       num_filter=num_filter, pad=pad, stride=stride,
+                                                       dilate=dilate, no_bias=no_bias)
+        qarg_names = quantized_conv.list_arguments()
+        type_dict = None
+        if not no_bias:
+            type_dict = {qarg_names[2]: 'int8'}
+        conv_exe_int8 = quantized_conv._simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
+        conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype)
+        conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
+        quantized_range = 127.0
+        if no_bias:
+            conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
+        else:
+            conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
+            conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
+        qoutput, min_range, max_range = conv_exe_int8.forward()
+
+        if no_bias:
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy(), atol=1)
+        else:
+            # with adding bias, accuracy loss should not be greater than one
+            diff = mx.nd.abs(output - qoutput.astype(output.dtype))
+            cond = mx.nd.lesser(2, diff).sum().asscalar()
+            assert cond == 0
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), True, qdtype)
+        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), False, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), False, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), True, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), False, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), True, qdtype)
+
+
+def test_quantized_elemwise_add():
+    def check_quantized_elemwise_add(data_shape, qtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_elemwise_add for native cpu since it is not supported yet')
+            return
+        elif qtype != 'uint8' and qtype != 'int8':
+            print('skipped testing quantized_elemwise_add for unsupported data type')
+            return
+        elif is_test_for_gpu():
+            print('skipped testing quantized_elemwise_add for gpu since it is not supported yet')
+            return
+
+        dataA = mx.sym.Variable(name='dataA', shape=data_shape, dtype='float32')
+        dataB = mx.sym.Variable(name='dataB', shape=data_shape, dtype='float32')
+        elemwise_add_fp32 = mx.sym.elemwise_add(dataA, dataB)
+        arg_names = elemwise_add_fp32.list_arguments()
+        elemwise_add_fp32_exe = elemwise_add_fp32._simple_bind(ctx=mx.current_context(), grad_req='null')
+        if qtype == 'uint8':
+            data_low = 0.0
+            data_high = 255.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+
+        dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32')
+        dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32')
+        elemwise_add_fp32_exe.arg_dict[arg_names[0]][:] = dataA_val
+        elemwise_add_fp32_exe.arg_dict[arg_names[1]][:] = dataB_val
+
+        output = elemwise_add_fp32_exe.forward()[0]
+        print(output)
+
+        qdataA = mx.sym.Variable(name='qdataA', shape=data_shape, dtype=qtype)
+        qdataB = mx.sym.Variable(name='qdataB', shape=data_shape, dtype=qtype)
+        min_dataA = mx.sym.Variable(name='min_dataA', dtype='float32')
+        max_dataA = mx.sym.Variable(name='max_dataA', dtype='float32')
+        min_dataB = mx.sym.Variable(name='min_dataB', dtype='float32')
+        max_dataB = mx.sym.Variable(name='max_dataB', dtype='float32')
+        quantized_elemwise_add = mx.sym.contrib.quantized_elemwise_add(qdataA, qdataB, min_dataA, max_dataA,
+                                                                       min_dataB, max_dataB)
+        elemwise_add_int8_exe = quantized_elemwise_add._simple_bind(ctx=mx.current_context(), grad_req='null')
+        qarg_names = quantized_elemwise_add.list_arguments()
+        elemwise_add_int8_exe.arg_dict[qarg_names[0]][:] = elemwise_add_fp32_exe.arg_dict[arg_names[0]].astype(qtype)
+        elemwise_add_int8_exe.arg_dict[qarg_names[1]][:] = elemwise_add_fp32_exe.arg_dict[arg_names[1]].astype(qtype)
+        quantized_range = 127.0
+        elemwise_add_int8_exe.arg_dict[qarg_names[2]][:] = data_low
+        elemwise_add_int8_exe.arg_dict[qarg_names[3]][:] = data_high
+        elemwise_add_int8_exe.arg_dict[qarg_names[4]][:] = data_low
+        elemwise_add_int8_exe.arg_dict[qarg_names[5]][:] = data_high
+        qoutput, min_range, max_range = elemwise_add_int8_exe.forward()
+        print(qoutput)
+
+        int8_rslt = qoutput.astype(output.dtype) * max_range / 0x7fffffff
+        print(int8_rslt)
+        diff = mx.nd.abs(output - int8_rslt)
+        cond = mx.nd.lesser(2, diff).sum().asscalar()
+        assert cond == 0
+
+    for qtype in ['int8', 'uint8']:
+        check_quantized_elemwise_add((4, 6), qtype)
+        # check_quantized_elemwise_add((13, 74, 52), qtype)

Review comment:
   Not enabled after debugging - thank you and fixed :)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]
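The baselines in this test file all implement the same symmetric int8 scheme: the quantizer picks scale = 127 / max(|min|, |max|) and rounds half away from zero, and the dequantizer multiplies back by the inverse scale. A minimal NumPy sketch of the round trip those baselines check (the helper names below are illustrative, not part of the PR):

```python
import numpy as np

def quantize_int8(data):
    # symmetric range: scale chosen so max(|data|) maps to +/-127,
    # matching the baselines in the diff above
    real_range = np.maximum(np.abs(data.min()), np.abs(data.max()))
    scale = 127.0 / real_range
    # round half away from zero, saturating at 127
    qdata = (np.sign(data) * np.minimum(np.abs(data) * scale + 0.5, 127.0)).astype(np.int8)
    return qdata, -real_range, real_range

def dequantize_int8(qdata, real_range):
    # inverse mapping back to float32
    return qdata.astype(np.float32) * (real_range / 127.0)

data = np.random.uniform(-10, 10, size=(4, 6)).astype(np.float32)
qdata, min_val, max_val = quantize_int8(data)
recovered = dequantize_int8(qdata, max_val)
# round-trip error is bounded by half a quantization step
assert np.all(np.abs(data - recovered) <= 0.5 * max_val / 127.0 + 1e-6)
```

Saturating at ±127 rather than using the full [-128, 127] int8 range keeps the representable interval symmetric around zero, which is why `quantized_range` is 127.0 throughout the tests.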
