DominikaJedynak opened a new pull request #20816: URL: https://github.com/apache/incubator-mxnet/pull/20816
## Description ##

This PR adds the possibility to fuse a dequantize node with a convolution node, which in practice lets us avoid unnecessarily multiplying and then dividing all entries of a convolution by the same scaling factor.

Speedup on various data sizes:

Measured on instance c6i.12xlarge (Intel Xeon Platinum 8375C), ami-04505e74c0741db8d (Canonical, Ubuntu, 20.04 LTS)

Script:

```
import mxnet as mx
from mxnet.contrib import quantization
from mxnet.gluon import nn
import gc
import time

batch_size = [1, 3, 8, 32, 64]
channels = [1, 3, 16, 64]
picture_size = [32, 64, 128, 248, 512, 1024]
DATA_SHAPE=[(n, c, s, s) for n in batch_size for c in channels for s in picture_size]
rounds = 1000
warmup = 100

def print_header(header):
    print( "\n---- ", header, " ----")
    print(" Shape | Time [s] | Mean [ms]" )

def print_value(shape, total, mean):
    print("({:4}, {:4}, {:4}, {:4}) | {:8.3f} | {:8.3f} ".format( shape[0], shape[1], shape[2], shape[3], total, mean))

def measure(net, data, shape):
    mx.nd.waitall()
    gc.collect()
    gc.disable()
    tic = 0
    for i in range(rounds + warmup):
        if i == warmup:
            start_time = time.time()
        o = net(data)
        o.wait_to_read()
    end_time = time.time()
    run_time = (end_time - start_time)
    print_value(shape, run_time, 1000 * run_time / rounds)
    gc.enable()
    gc.collect()

class Conv(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Conv, self).__init__(**kwargs)
        self.conv0 = nn.Conv2D(channels=4, kernel_size=(3, 3), strides=1, use_bias=False)

    def forward(self, x):
        out = self.conv0(x)
        return out

def benchmark():
    for data_shape in DATA_SHAPE:
        net = Conv()
        net.initialize()
        net.hybridize(static_alloc=True, static_shape=True)
        x = mx.np.random.uniform(size=data_shape, low=-1.0, high=1.0)
        data = mx.gluon.data.ArrayDataset(x)
        calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
        net = quantization.quantize_net(net, ctx=mx.current_context(), calib_mode='naive', calib_data=calib_data, num_calib_batches=1, )
        measure(net, x, data_shape)

benchmark()
```

-- This is an 
automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
