KellenSunderland closed pull request #13755: Do not merge: demonstrate mshadow size cast build.
URL: https://github.com/apache/incubator-mxnet/pull/13755
 
 
   

This is a PR from a forked repository.
As GitHub hides the original diff once such a PR is merged or closed, it is
displayed below for the sake of provenance:

Because this is a foreign pull request (from a fork), GitHub would not
otherwise show its diff, so the diff is supplied below:

diff --git a/.gitignore b/.gitignore
index 7eb8e7d6e77..9a145f5b8af 100644
--- a/.gitignore
+++ b/.gitignore
@@ -66,6 +66,7 @@ __pycache__
 build
 cmake-build*
 data
+model
 recommonmark
 deps
 
diff --git a/cpp-package/include/mxnet-cpp/monitor.h 
b/cpp-package/include/mxnet-cpp/monitor.h
index c1494d0bd0a..76e7ce836f1 100644
--- a/cpp-package/include/mxnet-cpp/monitor.h
+++ b/cpp-package/include/mxnet-cpp/monitor.h
@@ -70,8 +70,9 @@ class Monitor {
   /*!
   * \brief install callback to executor. Supports installing to multiple 
executors.
   * \param exe The executor to install to.
+  * \param monitor_all If true, monitor both input and output, otherwise 
monitor output only.
   */
-  void install(Executor *exe);
+  void install(Executor *exe, bool monitor_all = false);
 
   /*!
   * \brief Start collecting stats for current batch. Call before calling 
forward.
diff --git a/cpp-package/include/mxnet-cpp/monitor.hpp 
b/cpp-package/include/mxnet-cpp/monitor.hpp
index f3584e2e809..bd7f1927e90 100644
--- a/cpp-package/include/mxnet-cpp/monitor.hpp
+++ b/cpp-package/include/mxnet-cpp/monitor.hpp
@@ -43,10 +43,10 @@ inline Monitor::Monitor(int interval, std::regex pattern, 
StatFunc stat_func)
   : interval(interval), pattern(pattern), stat_func(stat_func), step(0) {
 }
 
-inline void Monitor::install(Executor *exe) {
+inline void Monitor::install(Executor *exe, bool monitor_all) {
   MXExecutorSetMonitorCallback(exe->handle_,
-      static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback),
-      this);
+                               
static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback),
+                               this, monitor_all);
   exes.push_back(exe);
 }
 
diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py 
b/example/quantization/imagenet_gen_qsym_mkldnn.py
index 938890bb75d..7fa7324beae 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -55,24 +55,24 @@ def convert_from_gluon(model_name, image_shape, 
classes=1000, logger=None):
     symnet = mx.symbol.load_json(y.tojson())
     params = net.collect_params()
     args = {}
-    auxs = {}    
+    auxs = {}
     for param in params.values():
         v = param._reduce()
         k = param.name
         if 'running' in k:
             auxs[k] = v
         else:
-            args[k] = v            
+            args[k] = v
     mod = mx.mod.Module(symbol=symnet, context=mx.cpu(),
                         label_names = ['softmax_label'])
-    mod.bind(for_training=False, 
-             data_shapes=[('data', (1,) + 
+    mod.bind(for_training=False,
+             data_shapes=[('data', (1,) +
                           tuple([int(i) for i in image_shape.split(',')]))])
     mod.set_params(arg_params=args, aux_params=auxs)
     dst_dir = os.path.join(dir_path, 'model')
     prefix = os.path.join(dir_path, 'model', model_name)
     if not os.path.isdir(dst_dir):
-        os.mkdir(dst_dir)       
+        os.mkdir(dst_dir)
     mod.save_checkpoint(prefix, 0)
     return prefix
 
@@ -104,7 +104,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
                              'you can set to custom to load your pre-trained 
model.')
     parser.add_argument('--use-gluon-model', type=bool, default=False,
                         help='If enabled, will download pretrained model from 
Gluon-CV '
-                             'and convert to symbolic model ')    
+                             'and convert to symbolic model ')
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--label-name', type=str, default='softmax_label')
     parser.add_argument('--calib-dataset', type=str, 
default='data/val_256_q90.rec',
@@ -114,7 +114,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
                         help='number of threads for data decoding')
     parser.add_argument('--num-calib-batches', type=int, default=10,
                         help='number of batches for calibration')
-    parser.add_argument('--exclude-first-conv', action='store_true', 
default=True,
+    parser.add_argument('--exclude-first-conv', action='store_true', 
default=False,
                         help='excluding quantizing the first conv layer since 
the'
                              ' input data may have negative value which 
doesn\'t support at moment' )
     parser.add_argument('--shuffle-dataset', action='store_true', default=True,
@@ -140,8 +140,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
                              ' thresholds. This mode is expected to produce 
the best inference accuracy of all three'
                              ' kinds of quantized models if the calibration 
dataset is representative enough of the'
                              ' inference dataset.')
-    parser.add_argument('--quantized-dtype', type=str, default='uint8',
-                        choices=['int8', 'uint8'],
+    parser.add_argument('--quantized-dtype', type=str, default='auto',
+                        choices=['auto', 'int8', 'uint8'],
                         help='quantization destination data type for input 
data')
     parser.add_argument('--enable-calib-quantize', type=bool, default=True,
                         help='If enabled, the quantize op will '
@@ -203,35 +203,30 @@ def save_params(fname, arg_params, aux_params, 
logger=None):
     if args.model == 'imagenet1k-resnet-152':
         rgb_mean = '0,0,0'
         rgb_std = '1,1,1'
-        calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['flatten0', 'fc1', 'pooling0']
+        excluded_sym_names += ['flatten0', 'fc1']
         if exclude_first_conv:
             excluded_sym_names += ['conv0']
     elif args.model == 'imagenet1k-inception-bn':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '1,1,1'
-        calib_layer = lambda name: name.endswith('_output')
         excluded_sym_names += ['flatten', 'fc1']
         if exclude_first_conv:
             excluded_sym_names += ['conv_1']
     elif args.model in ['resnet50_v1', 'resnet101_v1']:
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
-        calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['resnetv10_dense0_fwd', 'resnetv10_pool0_fwd']
+        excluded_sym_names += ['resnetv10_dense0_fwd']
         if exclude_first_conv:
             excluded_sym_names += ['resnetv10_conv0_fwd']
     elif args.model == 'squeezenet1.0':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
-        calib_layer = lambda name: name.endswith('_output')
         excluded_sym_names += ['squeezenet0_flatten0_flatten0']
         if exclude_first_conv:
             excluded_sym_names += ['squeezenet0_conv0_fwd']
     elif args.model == 'mobilenet1.0':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
-        calib_layer = lambda name: name.endswith('_output')
         excluded_sym_names += ['mobilenet0_flatten0_flatten0',
                                'mobilenet0_dense0_fwd',
                                'mobilenet0_pool0_fwd']
@@ -240,16 +235,13 @@ def save_params(fname, arg_params, aux_params, 
logger=None):
     elif args.model == 'inceptionv3':
         rgb_mean = '123.68,116.779,103.939'
         rgb_std = '58.393, 57.12, 57.375'
-        calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['inception30_dense0_fwd',
-                               'inception30_pool0_fwd']
+        excluded_sym_names += ['inception30_dense0_fwd']
         if exclude_first_conv:
             excluded_sym_names += ['inception30_conv0_fwd']
     elif args.model == 'custom':
         # add rgb mean/std of your model.
         rgb_mean = '0,0,0'
         rgb_std = '0,0,0'
-        calib_layer = lambda name: name.endswith('_output')
         # add layer names you donnot want to quantize.
         # add conv/pool layer names that has negative inputs
         # since Intel MKL-DNN only support uint8 quantization temporary.
@@ -302,9 +294,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
                                                         ctx=ctx, 
excluded_sym_names=excluded_sym_names,
                                                         calib_mode=calib_mode, 
calib_data=data,
                                                         
num_calib_examples=num_calib_batches * batch_size,
-                                                        
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
-                                                        
label_names=(label_name,), calib_quantize_op = True,
-                                                        logger=logger)
+                                                        calib_layer=None, 
quantized_dtype=args.quantized_dtype,
+                                                        
label_names=(label_name,), logger=logger)
         if calib_mode == 'entropy':
             suffix = '-quantized-%dbatches-entropy' % num_calib_batches
         elif calib_mode == 'naive':
diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py
index 231cc99f93b..8cdde894dc2 100644
--- a/example/ssd/quantization.py
+++ b/example/ssd/quantization.py
@@ -51,7 +51,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--num-calib-batches', type=int, default=5,
                         help='number of batches for calibration')
-    parser.add_argument('--exclude-first-conv', action='store_true', 
default=True,
+    parser.add_argument('--exclude-first-conv', action='store_true', 
default=False,
                         help='excluding quantizing the first conv layer since 
the'
                              ' number of channels is usually not a multiple of 
4 in that layer'
                              ' which does not satisfy the requirement of 
cuDNN')
@@ -78,8 +78,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
                              ' thresholds. This mode is expected to produce 
the best inference accuracy of all three'
                              ' kinds of quantized models if the calibration 
dataset is representative enough of the'
                              ' inference dataset.')
-    parser.add_argument('--quantized-dtype', type=str, default='uint8',
-                        choices=['int8', 'uint8'],
+    parser.add_argument('--quantized-dtype', type=str, default='auto',
+                        choices=['auto', 'int8', 'uint8'],
                         help='quantization destination data type for input 
data')
 
     args = parser.parse_args()
@@ -119,12 +119,9 @@ def save_params(fname, arg_params, aux_params, 
logger=None):
     exclude_first_conv = args.exclude_first_conv
     excluded_sym_names = []
     rgb_mean = '123,117,104'
-    calib_layer = lambda name: name.endswith('_output')
     for i in range(1,19):
         excluded_sym_names += ['flatten'+str(i)]
-    excluded_sym_names += ['relu4_3_cls_pred_conv',
-                            'relu7_cls_pred_conv',
-                            'relu4_3_loc_pred_conv']
+
     if exclude_first_conv:
         excluded_sym_names += ['conv1_1']
 
@@ -156,10 +153,8 @@ def save_params(fname, arg_params, aux_params, 
logger=None):
                                                         ctx=ctx, 
excluded_sym_names=excluded_sym_names,
                                                         calib_mode=calib_mode, 
calib_data=eval_iter,
                                                         
num_calib_examples=num_calib_batches * batch_size,
-                                                        
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
-                                                        
label_names=(label_name,),
-                                                        calib_quantize_op = 
True,
-                                                        logger=logger)
+                                                        calib_layer=None, 
quantized_dtype=args.quantized_dtype,
+                                                        
label_names=(label_name,), logger=logger)
         sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300')
         param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', 
epoch)
     qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index e9f1e2d6ccc..1c7575ccd68 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1556,13 +1556,12 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
  * \param num_offline number of parameters that are quantized offline
  * \param offline_params array of c strings representing the names of params 
quantized offline
  * \param quantized_dtype the quantized destination type for input data.
- * \param calib_quantize whether calibrate quantize op with offline 
calibration data.
  */
 MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle 
*ret_sym_handle,
                                const mx_uint num_excluded_symbols,
                                const char **excluded_symbols,
                                const mx_uint num_offline, const char 
**offline_params,
-                               const char *quantized_dtype, const bool 
calib_quantize);
+                               const char *quantized_dtype);
 
 /*!
  * \brief Set calibration table to node attributes in the sym
@@ -1833,10 +1832,12 @@ MXNET_DLL int 
MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
 
 /*!
  * \brief set a call back to notify the completion of operation
+ * \param monitor_all If true, monitor both input and output, otherwise 
monitor output only.
  */
 MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                            ExecutorMonitorCallback callback,
-                                           void* callback_handle);
+                                           void* callback_handle,
+                                           bool monitor_all);
 //--------------------------------------------
 // Part 5: IO Interface
 //--------------------------------------------
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 0ab04b86a0a..877b1300e26 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -174,7 +174,7 @@ class Executor {
   /*!
    * \brief Install a callback to notify the completion of operation.
    */
-  virtual void SetMonitorCallback(const MonitorCallback& callback) {}
+  virtual void SetMonitorCallback(const MonitorCallback& callback, bool 
monitor_all) {}
 };  // class executor
 }  // namespace mxnet
 #endif  // MXNET_EXECUTOR_H_
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 4ba13ca6498..5de42e19a65 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -694,9 +694,13 @@ class NDArray {
   /*
    * Create NDArray from mkldnn memory.
    * mkldnn_mem The mkldnn memory to be managed.
-   * static_data If true, mkldnn memory won't be freed on destruction.
    */
-  explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
+  explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
+  /*
+   * Create NDArray from mkldnn memory descriptor.
+   * mem_pd The mkldnn memory descriptor to be created.
+   */
+  explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
   /*
    * Test if the data is stored in one of special MKLDNN format.
    */
@@ -776,7 +780,7 @@ class NDArray {
    /*!
    * \ Fix mkldnn memory descriptor mismatch from NDArray.
    */
-  void UpdateMKLDNNMemDesc();
+  void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
 #endif
 
   /*!
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 496e8c7cfce..412877a5821 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -287,7 +287,7 @@ class TBlob {
     CHECK(Device::kDevMask == this->dev_mask())
       << "TBlob.get: device type do not match specified type";
     CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be 
contiguous";
-    CHECK_EQ(this->shape_.Size(), shape.Size())
+    CHECK_EQ(this->shape_.Size(), static_cast<size_t>(shape.Size()))
       << "TBlob.get_with_shape: new and old shape do not match total elements";
     return mshadow::Tensor<Device, dim, DType>(dptr<DType>(), shape,
                                                shape[dim - 1], stream);
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i 
b/perl-package/AI-MXNetCAPI/mxnet.i
index b1907f5cd7e..ca6623572df 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -1614,10 +1614,12 @@ int MXExecutorReshape(int partial_shaping,
 
 /*!
  * \brief set a call back to notify the completion of operation
+ * \param monitor_all If true, monitor both input and output, otherwise 
monitor output only.
  */
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                            ExecutorMonitorCallback callback,
-                                           void* callback_handle);
+                                           void* callback_handle,
+                                           bool monitor_all);
 //--------------------------------------------
 // Part 5: IO Interface
 //--------------------------------------------
@@ -2167,4 +2169,3 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int 
dev_id, void** cuda_kernel_
                                   mx_uint grid_dim_z, mx_uint block_dim_x,
                                   mx_uint block_dim_y, mx_uint block_dim_z,
                                   mx_uint shared_mem);
-
diff --git a/python/mxnet/contrib/quantization.py 
b/python/mxnet/contrib/quantization.py
index 61ad8a3ec70..0f32cbc82f3 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -26,6 +26,7 @@
 import ctypes
 import logging
 import os
+import sys
 import numpy as np
 from ..base import _LIB, check_call, py_str
 from ..base import c_array, c_str, mx_uint, c_str_array
@@ -80,8 +81,7 @@ def _quantize_params(qsym, params, th_dict):
                 quantized_params[name] = ndarray.array([th_dict[output][1]])
     return quantized_params
 
-def _quantize_symbol(sym, excluded_symbols=None, offline_params=None,
-                     quantized_dtype='int8', calib_quantize_op=False):
+def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, 
quantized_dtype='int8'):
     """Given a symbol object representing a neural network of data type FP32,
     quantize it into a INT8 network.
 
@@ -98,8 +98,6 @@ def _quantize_symbol(sym, excluded_symbols=None, 
offline_params=None,
         avoided.
     quantized_dtype: str
         The quantized destination type for input data.
-    calib_quantize_op : bool
-        Whether perform offline calibration for quantize op.
     """
     num_excluded_symbols = 0
     if excluded_symbols is not None:
@@ -122,8 +120,7 @@ def _quantize_symbol(sym, excluded_symbols=None, 
offline_params=None,
                                      c_str_array(excluded_symbols),
                                      mx_uint(num_offline),
                                      c_array(ctypes.c_char_p, offline),
-                                     c_str(quantized_dtype),
-                                     ctypes.c_bool(calib_quantize_op)))
+                                     c_str(quantized_dtype)))
     return Symbol(out)
 
 
@@ -139,18 +136,20 @@ def __init__(self, include_layer=None, logger=None):
 
     def collect(self, name, arr):
         """Callback function for collecting layer output NDArrays."""
-        name = py_str(name)
-        if self.include_layer is not None and not self.include_layer(name):
-            return
-        handle = ctypes.cast(arr, NDArrayHandle)
-        arr = NDArray(handle, writable=False).copyto(cpu())
-        if self.logger is not None:
-            self.logger.info("Collecting layer %s output of shape %s" % (name, 
arr.shape))
-        if name in self.nd_dict:
-            self.nd_dict[name].append(arr)
-        else:
-            self.nd_dict[name] = [arr]
-
+        try:
+            name = py_str(name)
+            if self.include_layer is not None and not self.include_layer(name):
+                return
+            handle = ctypes.cast(arr, NDArrayHandle)
+            arr = NDArray(handle, writable=False).copyto(cpu())
+            if self.logger is not None:
+                self.logger.info("Collecting layer %s output of shape %s" % 
(name, arr.shape))
+            if name in self.nd_dict:
+                self.nd_dict[name].append(arr)
+            else:
+                self.nd_dict[name] = [arr]
+        except KeyboardInterrupt:
+            sys.exit(1)
 
 class _LayerOutputMinMaxCollector(object):
     """Saves layer output min and max values in a dict with layer names as 
keys.
@@ -163,23 +162,25 @@ def __init__(self, include_layer=None, logger=None):
 
     def collect(self, name, arr):
         """Callback function for collecting min and max values from an 
NDArray."""
-        name = py_str(name)
-        if self.include_layer is not None and not self.include_layer(name):
-            return
-        handle = ctypes.cast(arr, NDArrayHandle)
-        arr = NDArray(handle, writable=False)
-        min_range = ndarray.min(arr).asscalar()
-        max_range = ndarray.max(arr).asscalar()
-        if name in self.min_max_dict:
-            cur_min_max = self.min_max_dict[name]
-            self.min_max_dict[name] = (min(cur_min_max[0], min_range),
-                                       max(cur_min_max[1], max_range))
-        else:
-            self.min_max_dict[name] = (min_range, max_range)
-        if self.logger is not None:
-            self.logger.info("Collecting layer %s output min_range=%f, 
max_range=%f"
-                             % (name, min_range, max_range))
-
+        try:
+            name = py_str(name)
+            if self.include_layer is not None and not self.include_layer(name):
+                return
+            handle = ctypes.cast(arr, NDArrayHandle)
+            arr = NDArray(handle, writable=False)
+            min_range = ndarray.min(arr).asscalar()
+            max_range = ndarray.max(arr).asscalar()
+            if name in self.min_max_dict:
+                cur_min_max = self.min_max_dict[name]
+                self.min_max_dict[name] = (min(cur_min_max[0], min_range),
+                                           max(cur_min_max[1], max_range))
+            else:
+                self.min_max_dict[name] = (min_range, max_range)
+            if self.logger is not None:
+                self.logger.info("Collecting layer %s min_range=%f, 
max_range=%f"
+                                 % (name, min_range, max_range))
+        except KeyboardInterrupt:
+            sys.exit(1)
 
 def _calibrate_quantized_sym(qsym, th_dict):
     """Given a dictionary containing the thresholds for quantizing the layers,
@@ -210,7 +211,7 @@ def _collect_layer_statistics(mod, data, collector, 
max_num_examples=None, logge
     if not isinstance(data, DataIter):
         raise ValueError('Only supports data as a type of DataIter, while 
received type %s'
                          % str(type(data)))
-    mod._exec_group.execs[0].set_monitor_callback(collector.collect)
+    mod._exec_group.execs[0].set_monitor_callback(collector.collect, 
monitor_all=True)
     num_batches = 0
     num_examples = 0
     for batch in data:
@@ -265,6 +266,9 @@ def _smooth_distribution(p, eps=0.0001):
 # pylint: disable=line-too-long
 def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
     """Given a dataset, find the optimal threshold for quantizing it.
+    The reference distribution is `q`, and the candidate distribution is `p`.
+    `q` is a truncated version of the original distribution.
+
     Ref: 
http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
     """
     if isinstance(arr, NDArray):
@@ -286,12 +290,12 @@ def _get_optimal_threshold(arr, num_bins=8001, 
num_quantized_bins=255):
     min_val = np.min(arr)
     max_val = np.max(arr)
     th = max(abs(min_val), abs(max_val))
+    if min_val >= 0:
+        num_quantized_bins = (num_quantized_bins // 2) * 4 + 1
 
     hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
     zero_bin_idx = num_bins // 2
     num_half_quantized_bins = num_quantized_bins // 2
-    assert np.allclose(hist_edges[zero_bin_idx] + hist_edges[zero_bin_idx + 1],
-                       0, rtol=1e-5, atol=1e-7)
 
     thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
     divergence = np.zeros_like(thresholds)
@@ -315,10 +319,10 @@ def _get_optimal_threshold(arr, num_bins=8001, 
num_quantized_bins=255):
         right_outlier_count = np.sum(hist[p_bin_idx_stop:])
         p[-1] += right_outlier_count
         # is_nonzeros[k] indicates whether hist[k] is nonzero
-        is_nonzeros = (sliced_nd_hist != 0).astype(np.int32)
+        is_nonzeros = (p != 0).astype(np.int32)
 
         # calculate how many bins should be merged to generate quantized 
distribution q
-        num_merged_bins = p.size // num_quantized_bins
+        num_merged_bins = sliced_nd_hist.size // num_quantized_bins
         # merge hist into num_quantized_bins bins
         for j in range(num_quantized_bins):
             start = j * num_merged_bins
@@ -326,17 +330,17 @@ def _get_optimal_threshold(arr, num_bins=8001, 
num_quantized_bins=255):
             quantized_bins[j] = sliced_nd_hist[start:stop].sum()
         quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * 
num_merged_bins:].sum()
         # expand quantized_bins into p.size bins
-        q = np.zeros(p.size, dtype=np.float32)
+        q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
         for j in range(num_quantized_bins):
             start = j * num_merged_bins
             if j == num_quantized_bins - 1:
-                stop = -1
+                stop = len(is_nonzeros)
             else:
                 stop = start + num_merged_bins
             norm = is_nonzeros[start:stop].sum()
             if norm != 0:
                 q[start:stop] = float(quantized_bins[j]) / float(norm)
-        q[sliced_nd_hist == 0] = 0
+        q[p == 0] = 0
         p = _smooth_distribution(p)
         # There is a chance that q is an invalid probability distribution.
         try:
@@ -344,7 +348,6 @@ def _get_optimal_threshold(arr, num_bins=8001, 
num_quantized_bins=255):
         except ValueError:
             divergence[i - num_half_quantized_bins] = float("inf")
         divergence[i - num_half_quantized_bins] = stats.entropy(p, q)
-        quantized_bins[:] = 0
 
     min_divergence_idx = np.argmin(divergence)
     min_divergence = divergence[min_divergence_idx]
@@ -424,7 +427,7 @@ def quantize_model(sym, arg_params, aux_params,
                    data_names=('data',), label_names=('softmax_label',),
                    ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
                    calib_data=None, num_calib_examples=None, calib_layer=None,
-                   quantized_dtype='int8', calib_quantize_op=False, 
logger=logging):
+                   quantized_dtype='int8', logger=logging):
     """User-level API for generating a quantized model from a FP32 model w/ or 
w/o calibration.
     The backend quantized operators are only enabled for Linux systems. Please 
do not run
     inference using the quantized models on Windows for now.
@@ -476,9 +479,8 @@ def quantize_model(sym, arg_params, aux_params,
         all the layers' outputs that need requantization will be collected.
     quantized_dtype : str
         The quantized destination type for input data. Currently support 'int8'
-        and 'uint8', default value is 'int8'.
-    calib_quantize_op: bool
-        Whether calibrate quantize op with its input calibration data. The 
quantize op's input should be in calib_layer
+        , 'uint8' and 'auto'. 'auto' means automatically select output type 
according to calibration result.
+        Default value is 'int8'.
     logger : Object
         A logging object for printing information during the process of 
quantization.
 
@@ -496,13 +498,12 @@ def quantize_model(sym, arg_params, aux_params,
                          ' while received type %s' % 
str(type(excluded_sym_names)))
 
     logger.info('Quantizing symbol')
-    if quantized_dtype not in ('int8', 'uint8'):
+    if quantized_dtype not in ('int8', 'uint8', 'auto'):
         raise ValueError('unknown quantized_dtype %s received,'
-                         ' expected `int8` or `uint8`' % quantized_dtype)
+                         ' expected `int8`, `uint8` or `auto`' % 
quantized_dtype)
     qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names,
                             offline_params=list(arg_params.keys()),
-                            quantized_dtype=quantized_dtype,
-                            calib_quantize_op=calib_quantize_op)
+                            quantized_dtype=quantized_dtype)
 
     th_dict = {}
     if calib_mode is not None and calib_mode != 'none':
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index fcd5406236e..ddb2dab1098 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -234,13 +234,15 @@ def backward(self, out_grads=None, is_train=True):
             ndarray,
             ctypes.c_int(is_train)))
 
-    def set_monitor_callback(self, callback):
+    def set_monitor_callback(self, callback, monitor_all=False):
         """Install callback for monitor.
 
         Parameters
         ----------
         callback : function
             Takes a string and an NDArrayHandle.
+        monitor_all : bool, default False
+            If true, monitor both input and output, otherwise monitor output 
only.
 
         Examples
         --------
@@ -254,7 +256,8 @@ def set_monitor_callback(self, callback):
         check_call(_LIB.MXExecutorSetMonitorCallback(
             self.handle,
             self._monitor_callback,
-            None))
+            None,
+            ctypes.c_int(monitor_all)))
 
     @property
     def arg_dict(self):
diff --git a/python/mxnet/monitor.py b/python/mxnet/monitor.py
index e3185a1281a..2e10708e72f 100644
--- a/python/mxnet/monitor.py
+++ b/python/mxnet/monitor.py
@@ -31,7 +31,7 @@
 
 
 class Monitor(object):
-    """Monitor outputs, weights, and gradients for debugging.
+    """Monitor inputs, outputs, weights, and gradients for debugging.
 
     Parameters
     ----------
@@ -46,8 +46,10 @@ class Monitor(object):
         Only tensors with names that match `name_pattern` will be included.
         For example, '.*weight|.*output' will print all weights and outputs and
         '.*backward.*' will print all gradients.
+    monitor_all : bool, default False
+        If true, monitor both input and output, otherwise monitor output only.
     """
-    def __init__(self, interval, stat_func=None, pattern='.*', sort=False):
+    def __init__(self, interval, stat_func=None, pattern='.*', sort=False, 
monitor_all=False):
         if stat_func is None:
             def asum_stat(x):
                 """returns |x|/size(x), async execution."""
@@ -61,6 +63,7 @@ def asum_stat(x):
         self.exes = []
         self.re_prog = re.compile(pattern)
         self.sort = sort
+        self.monitor_all = monitor_all
         def stat_helper(name, array):
             """wrapper for executor callback"""
             array = ctypes.cast(array, NDArrayHandle)
@@ -79,7 +82,7 @@ def install(self, exe):
         exe : mx.executor.Executor
             The Executor (returned by symbol.bind) to install to.
         """
-        exe.set_monitor_callback(self.stat_helper)
+        exe.set_monitor_callback(self.stat_helper, self.monitor_all)
         self.exes.append(exe)
 
     def tic(self):
diff --git 
a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc 
b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 17d166eac34..663d3a4142f 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -915,7 +915,8 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallbac
   jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj);
   return 
MXExecutorSetMonitorCallback(reinterpret_cast<ExecutorHandle>(executorPtr),
                                       ExecutorMonitorCallbackFunc,
-                                      reinterpret_cast<void 
*>(callbackFuncObjGlb));
+                                      reinterpret_cast<void 
*>(callbackFuncObjGlb),
+                                      false);
 }
 
 JNIEXPORT jstring JNICALL Java_org_apache_mxnet_LibInfo_mxGetLastError(JNIEnv 
* env, jobject obj) {
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index e2e53c7261f..b15f2d50864 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -649,7 +649,8 @@ int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
 
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
-                                 void* callback_handle) {
+                                 void* callback_handle,
+                                 bool monitor_all) {
   API_BEGIN();
   ExecutorMonitorCallback callback_temp = callback;
   void* callback_handle_temp = callback_handle;
@@ -658,6 +659,6 @@ int MXExecutorSetMonitorCallback(ExecutorHandle handle,
     callback_temp(name, handle, callback_handle_temp);
   };
   Executor *exec = static_cast<Executor*>(handle);
-  exec->SetMonitorCallback(clbk);
+  exec->SetMonitorCallback(clbk, monitor_all);
   API_END();
 }
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 73a8a7ca6f8..0a49b88e542 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -650,8 +650,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
                      const char **excluded_op_names,
                      const mx_uint num_offline,
                      const char **offline_params,
-                     const char *quantized_dtype,
-                     const bool calib_quantize) {
+                     const char *quantized_dtype) {
   nnvm::Symbol *s = new nnvm::Symbol();
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
@@ -668,7 +667,6 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
   g.attrs["excluded_nodes"] = 
std::make_shared<nnvm::any>(std::move(excluded_node_names));
   g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
   g.attrs["quantized_dtype"] = 
std::make_shared<nnvm::any>(std::move(quantized_type));
-  g.attrs["calib_quantize"] = std::make_shared<nnvm::any>(calib_quantize);
   g = ApplyPass(std::move(g), "QuantizeGraph");
   s->outputs = g.outputs;
   *ret_sym_handle = s;
@@ -685,10 +683,9 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle 
qsym_handle,
   API_BEGIN();
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(qsym_handle);
   nnvm::Graph g = Symbol2Graph(*sym);
-  const std::string prefix = "quantized_";
   std::unordered_map<std::string, std::pair<float, float>> calib_table;
   for (size_t i = 0; i < num_layers; ++i) {
-    calib_table.emplace(prefix+layer_names[i], std::make_pair(min_ranges[i], 
max_ranges[i]));
+    calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], 
max_ranges[i]));
   }
   g.attrs["calib_table"] = std::make_shared<nnvm::any>(std::move(calib_table));
   g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph");
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index d866ad13557..8302dc133c6 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -101,9 +101,10 @@ void GraphExecutor::Print(std::ostream &os) const {  // 
NOLINT(*)
   os << "Total " << 11 << " TempSpace resource requested\n";
 }
 
-void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback) {
+void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool 
monitor_all) {
   CHECK(callback) << "invalid callback";
   monitor_callback_ = callback;
+  monitor_all_ = monitor_all;
 }
 
 const std::vector<NDArray>& GraphExecutor::outputs() const {
@@ -1291,7 +1292,36 @@ void GraphExecutor::BulkInferenceOpSegs() {
   }
 }
 
-void GraphExecutor::ExecuteMonCallback(size_t nid) {
+void GraphExecutor::ExecuteMonInputCallback(size_t nid) {
+  static const auto& flist_inputs =
+      nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
+  const auto& idx = graph_.indexed_graph();
+  std::vector<std::string> input_names;
+  OpNode& opnode = op_nodes_[nid];
+  const auto& inode = idx[nid];
+  const auto& node = idx[nid].source;
+  if (flist_inputs.count(node->op())) {
+    input_names = flist_inputs[node->op()](node->attrs);
+  } else {
+    for (size_t i = 0; i < node->num_inputs(); ++i) {
+      input_names.emplace_back("input" + std::to_string(i));
+    }
+  }
+  CHECK_EQ(opnode.exec->in_array.size(), input_names.size());
+  for (size_t i = 0; i < opnode.exec->in_array.size(); ++i) {
+    if (node->inputs[i].node->is_variable()) {
+    // Monitor variable
+    NDArray *cpy = new NDArray(opnode.exec->in_array[i]);
+    std::string name = node->inputs[i].node->attrs.name;
+    this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
+    }
+    NDArray *cpy = new NDArray(opnode.exec->in_array[i]);
+    std::string name = inode.source->attrs.name + "_" + input_names[i];
+    this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
+  }
+}
+
+void GraphExecutor::ExecuteMonOutputCallback(size_t nid) {
   static const auto& flist_outputs =
       nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
   const auto& idx = graph_.indexed_graph();
@@ -1341,6 +1371,10 @@ void GraphExecutor::RunOps(bool is_train, size_t 
topo_start, size_t topo_end) {
     if (inode.source->is_variable()) continue;
     OpNode& opnode = op_nodes_[nid];
     if (op_nodes_[nid].skip_exec_node) continue;
+    // Monitor callbacks
+    if (monitor_callback_ && monitor_all_) {
+      ExecuteMonInputCallback(nid);
+    }
     opnode.exec->op_ctx.is_train = is_train;
     opnode.exec->op_ctx.need_grad = need_grad_;
     if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
@@ -1359,7 +1393,7 @@ void GraphExecutor::RunOps(bool is_train, size_t 
topo_start, size_t topo_end) {
     }
     // Monitor callbacks
     if (monitor_callback_) {
-      ExecuteMonCallback(nid);
+      ExecuteMonOutputCallback(nid);
     }
   }
 }
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index f5f032e3f2e..722714716aa 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -68,7 +68,7 @@ class GraphExecutor : public Executor {
   const std::unordered_map<std::string, NDArray>& arg_grad_map() const 
override;
   const std::unordered_map<std::string, NDArray>& aux_state_map() const 
override;
   void Print(std::ostream &os) const override; // NOLINT(*)
-  void SetMonitorCallback(const MonitorCallback& callback) override;
+  void SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) 
override;
   // Initialize the rest of attributes
   // after setting up arguments.
   void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g,
@@ -209,8 +209,10 @@ class GraphExecutor : public Executor {
    *  ret.opr Can be nullptr if creation failed.
   */
   CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end);
-  // run the monitor callback for node `nid`
-  void ExecuteMonCallback(size_t nid);
+  // run the monitor callback for input of node `nid`
+  void ExecuteMonInputCallback(size_t nid);
+  // run the monitor callback for output of node `nid`
+  void ExecuteMonOutputCallback(size_t nid);
   // peform bulking and segmentation on an inference graph
   void BulkInferenceOpSegs();
   // perform bulking and segmentation on a training graph
@@ -250,6 +252,8 @@ class GraphExecutor : public Executor {
   size_t num_forward_nodes_{0};
   // monitor call back
   std::function<void(const char*, void*)> monitor_callback_{nullptr};
+  // monitor both input and output from monitor call back
+  bool monitor_all_{false};
   // whether to enable bulk execution
   bool prefer_bulk_execution_;
   // cached segment operator
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 5a4cb29bc21..da470257d01 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
 
 #if MXNET_USE_MKLDNN == 1
 
-NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data)
+NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
     : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
-  auto mem_pd = mkldnn_mem->get_primitive_desc();
   auto mem_desc = mem_pd.desc();
   shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + 
mem_desc.data.ndims);
   dtype_ = get_mxnet_type(mem_desc.data.data_type);
-  auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, 
dtype_);
-  ptr_ = std::make_shared<Chunk>(data, 0);
+  ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+  ptr_->CheckAndAlloc(mem_pd.get_size());
   ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
-  ptr_->static_data = static_data;
+}
+
+NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
+    : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
+  auto mem_pd = mkldnn_mem->get_primitive_desc();
+  auto mem_desc = mem_pd.desc();
+  shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + 
mem_desc.data.ndims);
+  dtype_ = get_mxnet_type(mem_desc.data.data_type);
+  ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+  ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
+  ptr_->shandle.size = mem_pd.get_size();
+  ptr_->delay_alloc = false;
+  ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
+  ptr_->static_data = true;
 }
 
 NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
@@ -716,19 +728,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const 
mkldnn::memory::primitive_desc &
   return ptr_->mkl_mem_->GetRaw();
 }
 
-void NDArray::UpdateMKLDNNMemDesc() {
+void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
   const mkldnn::memory *mem = GetMKLDNNData();
   auto mem_desc = mem->get_primitive_desc().desc();
   auto this_dtype = get_mkldnn_type(dtype());
-  if (this_dtype != mem_desc.data.data_type) {
-    mkldnn::memory::desc data_md(
-        mkldnn::memory::dims(mem_desc.data.dims,
-                             mem_desc.data.dims + mem_desc.data.ndims),
-        this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
-    mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
-    ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
-  }
+  mkldnn::memory::desc data_md(
+      mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + 
mem_desc.data.ndims),
+      this_dtype, format);
+  mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
 }
 #endif
 
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h 
b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 17e74094c2b..660a27d8be6 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -189,6 +189,9 @@ static int GetTypeSize(int dtype) {
 }
 
 static inline size_t GetArraySize(const NDArray &arr) {
+  if (arr.IsMKLDNNData()) {
+    return arr.GetMKLDNNData()->get_primitive_desc().get_size();
+  }
   return arr.shape().Size() * GetTypeSize(arr.dtype());
 }
 
@@ -237,21 +240,20 @@ static inline size_t GetMemDescSize(const 
mkldnn::memory::desc &md) {
   return ret;
 }
 
-inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int ndim) {
+inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = 
-1) {
+  int ndim = arr.shape().ndim();
   mkldnn::memory::dims dims(ndim);
+  dtype = (dtype == -1) ? arr.dtype() : dtype;
   for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
-  return mkldnn::memory::desc{dims, get_mkldnn_type(arr.dtype()),
-                              mkldnn::memory::format::any};
-}
-
-inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr) {
-  return GetMemDesc(arr, arr.shape().ndim());
+  return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), 
mkldnn::memory::format::any};
 }
 
 inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
-                                                 int num_groups) {
+                                                 int num_groups,
+                                                 bool quantized = false) {
+  int dtype = quantized ? mshadow::kInt8 : arr.dtype();
   if (num_groups == 1) {
-    return GetMemDesc(arr);
+    return GetMemDesc(arr, dtype);
   } else {
     CHECK_EQ(arr.shape().ndim(), 4U);
     mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
@@ -259,7 +261,7 @@ inline static mkldnn::memory::desc GetWeightDesc(const 
NDArray &arr,
       static_cast<int>(arr.shape()[1]),
       static_cast<int>(arr.shape()[2]),
       static_cast<int>(arr.shape()[3])};
-    return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
+    return mkldnn::memory::desc{tz, get_mkldnn_type(dtype),
                                 mkldnn::memory::format::any};
   }
 }
@@ -437,6 +439,8 @@ static inline void CreateDefaultInputs(const 
std::vector<NDArray> &arrs,
   }
 }
 
+const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups);
+
 const mkldnn::memory *GetWeights(const NDArray &arr,
                                  const mkldnn::memory::primitive_desc 
&target_pd,
                                  int num_groups);
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc 
b/src/operator/nn/mkldnn/mkldnn_base.cc
index 5da55f4ca70..0cf44d3d78f 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -229,51 +229,44 @@ void CommitOutput(const NDArray &arr, const 
mkldnn_output_t &res) {
   }
 }
 
-const mkldnn::memory *GetWeights(const NDArray &arr,
-                                 const mkldnn::memory::primitive_desc 
&target_pd,
-                                 int num_groups) {
-  const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd);
-  // If the weight array already uses the target layout, simply return it
-  // directly.
-  if (mem)
-    return mem;
-
+const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
   mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype());
+  const mkldnn::memory *mem = nullptr;
   auto engine = CpuEngine::Get()->get_engine();
   if (arr.shape().ndim() == 2) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{
-      static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
+    mkldnn::memory::dims tz =
+        mkldnn::memory::dims{static_cast<int>(arr.shape()[0]), 
static_cast<int>(arr.shape()[1])};
+    mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, 
mkldnn::memory::format::oi};
+    mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, 
engine};
     mem = arr.GetMKLDNNData(pd);
   } else if (arr.shape().ndim() == 4 && num_groups == 1) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{
-      static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1]),
-          static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
+    mkldnn::memory::dims tz =
+        mkldnn::memory::dims{static_cast<int>(arr.shape()[0]), 
static_cast<int>(arr.shape()[1]),
+                             static_cast<int>(arr.shape()[2]), 
static_cast<int>(arr.shape()[3])};
+    mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, 
mkldnn::memory::format::oihw};
+    mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, 
engine};
     mem = arr.GetMKLDNNData(pd);
   } else if (arr.shape().ndim() == 4) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
-      static_cast<int>(arr.shape()[0] / num_groups),
-      static_cast<int>(arr.shape()[1]),
-      static_cast<int>(arr.shape()[2]),
-      static_cast<int>(arr.shape()[3])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
+    mkldnn::memory::dims tz = mkldnn::memory::dims{
+        num_groups, static_cast<int>(arr.shape()[0] / num_groups), 
static_cast<int>(arr.shape()[1]),
+        static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
+    mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, 
mkldnn::memory::format::goihw};
+    mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, 
engine};
     mem = arr.GetMKLDNNData(pd);
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
-    return nullptr;
   }
-  if (mem == nullptr)
-    mem = arr.GetMKLDNNDataReorder(target_pd);
+  return mem;
+}
+
+const mkldnn::memory *GetWeights(const NDArray &arr,
+                                 const mkldnn::memory::primitive_desc 
&target_pd, int num_groups) {
+  const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd);
+  // If the weight array already uses the target layout, simply return it
+  // directly.
+  if (mem) return mem;
+  mem = GetWeights(arr, num_groups);
+  if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_pd);
   if (mem->get_primitive_desc() == target_pd) return mem;
 
   auto ret = TmpMemMgr::Get()->Alloc(target_pd);
@@ -315,6 +308,7 @@ mkldnn_memory_format_t GetDefaultFormat(const 
mkldnn::memory::desc &desc) {
       case mkldnn_oIhw8i:
       case mkldnn_oIhw16i:
       case mkldnn_OIhw8i8o:
+      case mkldnn_hwio_s8s8:
       case mkldnn_OIhw16i16o:
       case mkldnn_OIhw4i16o4i:
       case mkldnn_OIhw4i16o4i_s8s8:
@@ -337,9 +331,11 @@ mkldnn_memory_format_t GetDefaultFormat(const 
mkldnn::memory::desc &desc) {
     switch (desc.data.format) {
       case mkldnn_goihw:
       case mkldnn_hwigo:
+      case mkldnn_hwigo_s8s8:
       case mkldnn_gOIhw8i8o:
       case mkldnn_gOIhw16i16o:
       case mkldnn_gOIhw4i16o4i:
+      case mkldnn_gOIhw4i16o4i_s8s8:
       case mkldnn_gOIhw8i16o2i:
       case mkldnn_gOIhw8o16i2o:
       case mkldnn_gOIhw8o8i:
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h 
b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
index 971c66ad9dd..a27dced910c 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -85,23 +85,28 @@ static inline bool IsOutputUInt8(const MKLDNNConvParam 
&mkldnn_param) {
          mkldnn_param.with_postsum_relu;
 }
 
-mkldnn::convolution_forward::primitive_desc
-GetConvFwdImpl(const MKLDNNConvFullParam &param, const bool is_train,
-               const NDArray &data, const NDArray &weights, const NDArray 
*bias,
-               const NDArray &output);
+mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const 
MKLDNNConvFullParam &param,
+                                                           const bool is_train,
+                                                           const NDArray &data,
+                                                           const NDArray 
&weights,
+                                                           const NDArray *bias,
+                                                           const NDArray 
&output);
 
 class MKLDNNConvForward {
  public:
   mkldnn::convolution_forward::primitive_desc fwd_pd;
 
-  MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train,
-                    const NDArray &data, const NDArray &weights,
-                    const NDArray *bias, const NDArray &output)
-      : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) {}
+  MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train, 
const NDArray &data,
+                    const NDArray &weights, const NDArray *bias, const NDArray 
&output);
 
   void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
                  const mkldnn::memory *bias, const mkldnn::memory &output);
 
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
+    this->data_->set_data_handle(data.get_data_handle());
+    this->out_->set_data_handle(output.get_data_handle());
+  }
+
   const mkldnn::convolution_forward &GetFwd() const {
     return *fwd_;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc 
b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index dd1f3ec07d7..955dfcf5d71 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -42,18 +42,12 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, 
const NDArray &input) {
   return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4;
 }
 
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
-    const MKLDNNConvFullParam &param, const bool is_train,
-    const NDArray &data, const NDArray &weights, const NDArray *bias,
-    const NDArray &output) {
-  auto prop = is_train ? mkldnn::prop_kind::forward_training : 
mkldnn::prop_kind::forward_scoring;
-  auto data_md = GetMemDesc(data);
-  auto weight_md = GetWeightDesc(weights, param.conv_param.num_group);
-  auto out_md = GetMemDesc(output);
+static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
+    const MKLDNNConvFullParam &param, const bool is_train, const 
mkldnn::memory::desc &data_md,
+    const mkldnn::memory::desc &weight_md, const mkldnn::memory::desc *bias_md,
+    const mkldnn::memory::desc &out_md) {
   auto engine = CpuEngine::Get()->get_engine();
-  CHECK_GE(param.conv_param.stride.ndim(), 2U);
-  CHECK_GE(param.conv_param.pad.ndim(), 2U);
-  CHECK_GE(param.conv_param.dilate.ndim(), 2U);
+  auto prop = is_train ? mkldnn::prop_kind::forward_training : 
mkldnn::prop_kind::forward_scoring;
   mkldnn::memory::dims strides{0, 0};
   strides[0] = param.conv_param.stride[0];
   strides[1] = param.conv_param.stride[1];
@@ -63,18 +57,18 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
   if (param.mkldnn_param.with_relu) {
-    float scale = 1.0f;            // for fp32, scale is 1.
-    float alpha = 0.0f;            // negative slope for mkldnn_eltwise_relu.
-    float beta = 1.0f;             // ignored for mkldnn_eltwise_relu.
+    float scale = 1.0f;  // for fp32, scale is 1.
+    float alpha = 0.0f;  // negative slope for mkldnn_eltwise_relu.
+    float beta = 1.0f;   // ignored for mkldnn_eltwise_relu.
     ops.append_eltwise(scale, eltwise_relu, alpha, beta);
   }
   if (param.mkldnn_param.with_sum) {
     ops.append_sum(param.sum_scale);
   }
   if (param.mkldnn_param.with_postsum_relu) {
-    float scale = 1.0f;            // for fp32, scale is 1.
-    float alpha = 0.0f;            // negative slope for mkldnn_eltwise_relu.
-    float beta = 1.0f;             // ignored for mkldnn_eltwise_relu.
+    float scale = 1.0f;  // for fp32, scale is 1.
+    float alpha = 0.0f;  // negative slope for mkldnn_eltwise_relu.
+    float beta = 1.0f;   // ignored for mkldnn_eltwise_relu.
     ops.append_eltwise(scale, eltwise_relu, alpha, beta);
   }
   attr.set_post_ops(ops);
@@ -85,62 +79,67 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     attr.set_int_output_round_mode(round_nearest);
   }
 
-  // MKL-DNN introduced padded formats since 0.15 which require more memory
-  // for computation compared with the actual tensor size. Currently, MKL-DNN
-  // operators are still reusing those memory from memory planning and the
-  // memory size may smaller than what MKL-DNN kernels require. So here we need
-  // select suboptimal kernel for computation according to tensor sizes.
-  if (param.conv_param.dilate.ndim() == 0 && bias == nullptr) {
-    mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct,
-        data_md, weight_md, out_md, strides, padding, padding, 
mkldnn::padding_kind::zero);
-    auto conv_pd =  mkldnn::convolution_forward::primitive_desc(desc, attr, 
engine);
-    while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-           conv_pd.weights_primitive_desc().get_size() != 
GetArraySize(weights)) {
-      CHECK(conv_pd.next_impl()) << "No implementation";
-    }
-    return conv_pd;
+  if (param.conv_param.dilate.ndim() == 0 && bias_md == nullptr) {
+    mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct, data_md,
+                                           weight_md, out_md, strides, 
padding, padding,
+                                           mkldnn::padding_kind::zero);
+    return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
   } else if (param.conv_param.dilate.ndim() == 0) {
-    auto bias_md = GetMemDesc(*bias);
-    mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct,
-        data_md, weight_md, bias_md, out_md, strides, padding, padding,
-        mkldnn::padding_kind::zero);
-    auto conv_pd =  mkldnn::convolution_forward::primitive_desc(desc, attr, 
engine);
-    while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
-           conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-           conv_pd.weights_primitive_desc().get_size() != 
GetArraySize(weights)) {
-      CHECK(conv_pd.next_impl()) << "No implementation";
-    }
-    return conv_pd;
+    mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct, data_md,
+                                           weight_md, *bias_md, out_md, 
strides, padding, padding,
+                                           mkldnn::padding_kind::zero);
+    return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
   } else {
     mkldnn::memory::dims dilates{0, 0};
     dilates[0] = param.conv_param.dilate[0] - 1;
     dilates[1] = param.conv_param.dilate[1] - 1;
-    if (bias == nullptr) {
-      mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct,
-          data_md, weight_md, out_md, strides, dilates, padding, padding,
-          mkldnn::padding_kind::zero);
-      auto conv_pd =  mkldnn::convolution_forward::primitive_desc(desc, attr, 
engine);
-      while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
-             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-             conv_pd.weights_primitive_desc().get_size() != 
GetArraySize(weights)) {
-        CHECK(conv_pd.next_impl()) << "No implementation";
-      }
-      return conv_pd;
-    } else {
-      auto bias_md = GetMemDesc(*bias);
-      mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct,
-                                             data_md, weight_md, bias_md, 
out_md, strides,
-                                             dilates, padding, padding,
+    if (bias_md == nullptr) {
+      mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct, data_md,
+                                             weight_md, out_md, strides, 
dilates, padding, padding,
                                              mkldnn::padding_kind::zero);
-      auto conv_pd =  mkldnn::convolution_forward::primitive_desc(desc, attr, 
engine);
-      while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
-             conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
-             conv_pd.weights_primitive_desc().get_size() != 
GetArraySize(weights)) {
-        CHECK(conv_pd.next_impl()) << "No implementation";
-      }
-      return conv_pd;
+      return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
+    } else {
+      mkldnn::convolution_forward::desc desc(prop, 
mkldnn::algorithm::convolution_direct, data_md,
+                                             weight_md, *bias_md, out_md, 
strides, dilates, padding,
+                                             padding, 
mkldnn::padding_kind::zero);
+      return mkldnn::convolution_forward::primitive_desc(desc, attr, engine);
+    }
+  }
+}
+
+mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const 
MKLDNNConvFullParam &param,
+                                                           const bool 
is_train, const NDArray &data,
+                                                           const NDArray 
&weights,
+                                                           const NDArray *bias,
+                                                           const NDArray 
&output) {
+  CHECK_GE(param.conv_param.stride.ndim(), 2U);
+  CHECK_GE(param.conv_param.pad.ndim(), 2U);
+  CHECK_GE(param.conv_param.dilate.ndim(), 2U);
+  auto data_md = GetMemDesc(data);
+  auto weight_md = GetWeightDesc(weights, param.conv_param.num_group, 
param.mkldnn_param.quantized);
+  auto out_md = GetMemDesc(output);
+  auto bias_md =
+      bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, 
mshadow::kInt32) : GetMemDesc(*bias))
+           : mkldnn::memory::desc{
+             {}, mkldnn::memory::data_type::data_undef, 
mkldnn::memory::format::any};
+  auto bias_md_ptr = bias ? &bias_md : nullptr;
+  try {
+    auto conv_pd = GetConvFwdImpl(param, is_train, data_md, weight_md, 
bias_md_ptr, out_md);
+    while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) ||
+           conv_pd.src_primitive_desc().get_size() != GetArraySize(data) ||
+           (!param.mkldnn_param.quantized &&
+            conv_pd.weights_primitive_desc().get_size() != 
GetArraySize(weights))) {
+      CHECK(conv_pd.next_impl()) << "No convolution implementation for this 
request.";
     }
+    return conv_pd;
+  } catch (mkldnn::error &e) {
+    if (e.status == mkldnn_unimplemented && param.mkldnn_param.quantized) {
+      LOG(ERROR) << "AVX512-BW support or Intel(R) MKL dependency is "
+                    "required for int8 convolution";
+    } else {
+      LOG(ERROR) << e.message;
+    }
+    throw;
   }
 }
 
@@ -270,48 +269,31 @@ static 
mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   }
 }
 
-void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data,
-                                  const mkldnn::memory &weight,
-                                  const mkldnn::memory *bias,
-                                  const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-            fwd_pd.src_primitive_desc(), data.get_data_handle()));
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-
-  if (this->weight_ == nullptr)
-    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-            fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-  else
-    this->weight_->set_data_handle(weight.get_data_handle());
-
-  if (this->out_ == nullptr)
-    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-            fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (bias != nullptr) {
-    if (this->bias_ == nullptr)
-      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-    else
-      this->bias_->set_data_handle(bias->get_data_handle());
-    if (this->fwd_ == nullptr)
-      this->fwd_ = std::shared_ptr<mkldnn::convolution_forward>(
-          new mkldnn::convolution_forward(fwd_pd, 
mkldnn::primitive::at(*this->data_),
-                                          
mkldnn::primitive::at(*this->weight_),
-                                          mkldnn::primitive::at(*this->bias_),
-                                          *this->out_));
-  } else if (this->fwd_ == nullptr) {
-    this->fwd_ = std::shared_ptr<mkldnn::convolution_forward>(
-        new mkldnn::convolution_forward(fwd_pd, 
mkldnn::primitive::at(*this->data_),
-                                        mkldnn::primitive::at(*this->weight_),
-                                        *this->out_));
+// Build the forward convolution once, at construction time.
+// The src/weights/dst (and, if present, bias) memory objects are created from fwd_pd with null
+// data handles, and the convolution primitive is bound to those objects. Real buffers are
+// attached later by SetNewMem through set_data_handle.
+MKLDNNConvForward::MKLDNNConvForward(const MKLDNNConvFullParam &param, const 
bool is_train,
+                                     const NDArray &data, const NDArray 
&weights,
+                                     const NDArray *bias, const NDArray 
&output)
+    : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) {
+  data_ = std::make_shared<mkldnn::memory>(fwd_pd.src_primitive_desc(), 
nullptr);
+  weight_ = std::make_shared<mkldnn::memory>(fwd_pd.weights_primitive_desc(), 
nullptr);
+  out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(), 
nullptr);
+  // The primitive takes a bias input only when a bias array is supplied here; bias_ stays
+  // unset otherwise.
+  if (bias) {
+    bias_ = std::make_shared<mkldnn::memory>(fwd_pd.bias_primitive_desc(), 
nullptr);
+    fwd_ = std::make_shared<mkldnn::convolution_forward>(fwd_pd, *this->data_, 
*this->weight_,
+                                                         *this->bias_, 
*this->out_);
+  } else {
+    fwd_ = std::make_shared<mkldnn::convolution_forward>(fwd_pd, *this->data_, 
*this->weight_,
+                                                         *this->out_);
   }
 }
 
+// Re-point the memory objects created by the constructor at this call's buffers.
+// The cached convolution primitive was built on those same memory objects, so only the raw
+// data handles are swapped; no primitive is rebuilt. A null `bias` leaves bias_ untouched.
+void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                                  const mkldnn::memory *bias, const mkldnn::memory &output) {
+  this->data_->set_data_handle(data.get_data_handle());
+  this->weight_->set_data_handle(weight.get_data_handle());
+  this->out_->set_data_handle(output.get_data_handle());
+  if (bias == nullptr) return;
+  this->bias_->set_data_handle(bias->get_data_handle());
+}
+
 MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
                               const bool is_train, const NDArray &data,
                               const NDArray &weights, const NDArray *bias,
diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h 
b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
index 89c3c199488..3c65172c611 100644
--- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
@@ -74,6 +74,9 @@ static void MKLDNNDequantizeComputeKer(const 
std::vector<NDArray> &inputs,
     i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
   }
   mkldnn::memory::format i_fmt = 
static_cast<mkldnn::memory::format>(i_desc.data.format);
+  if (i_fmt == mkldnn::memory::format::nhwc) {
+    i_fmt = mkldnn::memory::format::nchw;
+  }
   auto o_desc = mkldnn::memory::desc(i_dims,
                                     
(mkldnn::memory::data_type)data_type_enum<DstType>::type,
                                     i_fmt);
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h 
b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
new file mode 100644
index 00000000000..806f67f01eb
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_quantize_v2-inl.h
+ * \brief MKL-DNN (CPU) implementation of the quantize_v2 operator
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
+#if MXNET_USE_MKLDNN == 1
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+#include "../quantize_v2-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Quantize inputs[0] (fp32) into outputs[0] as DstType using a single MKL-DNN reorder.
+// The reorder is scaled by quantized_range / real_range, and the represented real range is
+// written to outputs[1] (min) and outputs[2] (max).
+// The real range comes from min_calib_range/max_calib_range when both are set. Otherwise it is
+// computed from the data with an OpenMP min/max scan.
+template <typename SrcType, typename DstType>
+static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
+                                     const std::vector<NDArray>& outputs,
+                                     const QuantizeV2Param& param,
+                                     const std::vector<OpReqType>& req) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using red::limits::MaxValue;
+  using red::limits::MinValue;
+  float real_range = 0.0;
+  float quantized_range = 0.0;
+  NDArray in_buffer = inputs[0];
+  float data_min = red::limits::MaxValue<float>();
+  float data_max = red::limits::MinValue<float>();
+
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    data_min = param.min_calib_range.value();
+    data_max = param.max_calib_range.value();
+  } else {
+    // no calib info: scan the default-layout input, one partial min/max per OpenMP thread,
+    // then merge the partial results serially.
+    in_buffer = inputs[0].Reorder2Default();
+    auto in_ptr = in_buffer.data().dptr<float>();
+    auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    std::vector<float> data_maxs(nthreads, data_max);
+    std::vector<float> data_mins(nthreads, data_min);
+#pragma omp parallel for num_threads(nthreads)
+    for (index_t i = 0; i < static_cast<index_t>(in_buffer.shape().Size()); i++) {
+      int tid = omp_get_thread_num();
+      if (in_ptr[i] > data_maxs[tid]) data_maxs[tid] = in_ptr[i];
+      if (in_ptr[i] < data_mins[tid]) data_mins[tid] = in_ptr[i];
+    }
+    for (index_t i = 0; i < nthreads; i++) {
+      if (data_maxs[i] > data_max) data_max = data_maxs[i];
+      if (data_mins[i] < data_min) data_min = data_mins[i];
+    }
+  }
+  auto out_type = GetOutputType(param);
+  if (out_type == mshadow::kUint8) {
+    real_range = MaxAbs(data_min, data_max);
+    quantized_range = MaxAbs(MaxValue<DstType>(), MinValue<DstType>());
+    *outputs[1].data().dptr<float>() = data_min;
+    *outputs[2].data().dptr<float>() = data_max;
+  } else if (out_type == mshadow::kInt8) {
+    real_range = MaxAbs(data_min, data_max);
+    quantized_range = MinAbs(MaxValue<DstType>(), MinValue<DstType>());
+    *outputs[1].data().dptr<float>() = -real_range;
+    *outputs[2].data().dptr<float>() = real_range;
+  } else {
+    LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+  }
+  // A zero real range (all-zero data, or a degenerate calibration range) would give an
+  // infinite scale, and 0 * inf = NaN inside the reorder. The reported output range is zero
+  // in that case, so a zero scale (every quantized element 0) is the consistent result.
+  const float scale = real_range > 0.0f ? quantized_range / real_range : 0.0f;
+
+  primitive_attr attr;
+  const int mask = 0;
+  std::vector<float> scales = {scale};
+  attr.set_output_scales(mask, scales);
+  attr.set_int_output_round_mode(round_nearest);
+  mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+
+  if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
+  auto i_mem = in_buffer.GetMKLDNNData();
+  auto i_mpd = i_mem->get_primitive_desc();
+  auto i_desc = i_mpd.desc();
+  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
+  // Plain and blocked nchw-family inputs are written out in nhwc layout.
+  if (i_fmt == mkldnn::memory::format::nchw ||
+      i_fmt == mkldnn::memory::format::nChw8c ||
+      i_fmt == mkldnn::memory::format::nChw16c) {
+    i_fmt = mkldnn::memory::format::nhwc;
+  }
+  size_t i_ndim = in_buffer.shape().ndim();
+  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
+  for (size_t i = 0; i < i_ndim; i++) {
+    i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+  }
+  auto o_desc =
+      mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum<DstType>::type, i_fmt);
+  auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
+  auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  CommitOutput(outputs[0], o_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+// FComputeEx entry point for the MKL-DNN quantize_v2 path.
+// The source is always fp32. The destination dtype is whatever GetOutputType resolves from
+// the parsed parameters.
+static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+  const QuantizeV2Param& param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+  switch (GetOutputType(param)) {
+    case mshadow::kUint8:
+      MKLDNNQuantizeComputeKer<float, uint8_t>(inputs, outputs, param, req);
+      break;
+    case mshadow::kInt8:
+      MKLDNNQuantizeComputeKer<float, int8_t>(inputs, outputs, param, req);
+      break;
+    default:
+      LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+      break;
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
diff --git a/src/operator/quantization/quantization_utils.h 
b/src/operator/quantization/quantization_utils.h
index ee711220589..efc84100970 100644
--- a/src/operator/quantization/quantization_utils.h
+++ b/src/operator/quantization/quantization_utils.h
@@ -27,6 +27,7 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include "../mxnet_op.h"
+#include "../tensor/broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
@@ -171,6 +172,20 @@ struct QuantizationRangeForMultiplicationStruct {
   }
 };
 
+// Compact a broadcast-style reduction of `data_shape` onto `out_shape` into a 2-D problem.
+// The compacted shapes are stored in *src_shape / *dst_shape, and the function returns the
+// number of workspace bytes broadcast::Reduce needs for it (kWriteTo).
+template<typename xpu, typename DType>
+inline size_t ConfigReduce(mshadow::Stream<xpu>* s,
+                           const TShape& data_shape,
+                           const TShape& out_shape,
+                           TShape* src_shape,
+                           TShape* dst_shape) {
+  constexpr int kReduceDim = 2;
+  BroadcastReduceShapeCompact(data_shape, out_shape, src_shape, dst_shape);
+  CHECK_EQ(src_shape->ndim(), kReduceDim);
+  CHECK_EQ(dst_shape->ndim(), kReduceDim);
+  const size_t workspace_bytes =
+      broadcast::ReduceWorkspaceSize<kReduceDim, DType>(s, *dst_shape, kWriteTo, *src_shape);
+  return workspace_bytes;
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_QUANTIZATION_QUANTIZATION_UTILS_H_
diff --git a/src/operator/quantization/quantize_graph_pass.cc 
b/src/operator/quantization/quantize_graph_pass.cc
index fcd0fb4218b..af533978a6f 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -26,6 +26,7 @@
 #include <nnvm/pass.h>
 #include <mxnet/op_attr_types.h>
 #include <unordered_set>
+#include "quantize_v2-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -63,12 +64,12 @@ NodePtr InsertNode(std::string op_name,
 }
 
 std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs,
-                                     std::unordered_set<std::string>&& 
offline_params) {
+                                     const std::unordered_set<std::string>& 
offline_params) {
   std::string node_suffixs[3] = {"", "_min", "_max"};
   std::unordered_map<Node*, NodePtr> mirror_map;
   nnvm::NodeEntryMap<NodePtr> entry_var;
   auto need_offline = [&](NodePtr n) {
-    return (n->op() == Op::Get("_contrib_quantize")) &&
+    return (n->op() == Op::Get("_contrib_quantize_v2")) &&
            n->inputs[0].node->is_variable() &&
            offline_params.count(n->inputs[0].node->attrs.name);
   };
@@ -88,7 +89,8 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& 
outputs,
   return outputs;
 }
 
-inline bool NeedQuantize(NodePtr node, const std::unordered_set<std::string>& 
excluded_nodes) {
+inline bool NeedQuantize(const NodePtr node,
+                         const std::unordered_set<std::string>& 
excluded_nodes) {
   static auto& quantized_op_map = 
Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
   const auto& op = node->op();
@@ -121,10 +123,9 @@ Graph QuantizeGraph(Graph &&src) {
   static const auto& need_requantize_map = 
Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
   static const auto& avoid_quantize_input_map =
       Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
-  auto offline_params = 
src.GetAttr<std::unordered_set<std::string>>("offline_params");
-  auto excluded_nodes = 
src.GetAttr<std::unordered_set<std::string>>("excluded_nodes");
-  auto quantized_dtype = src.GetAttr<std::string>("quantized_dtype");
-  auto calib_quantize = src.GetAttr<bool>("calib_quantize");
+  const auto offline_params = 
src.GetAttr<std::unordered_set<std::string>>("offline_params");
+  const auto excluded_nodes = 
src.GetAttr<std::unordered_set<std::string>>("excluded_nodes");
+  const auto quantized_dtype = src.GetAttr<std::string>("quantized_dtype");
 
   // mirror_map stores the mapping from the currently visited graph to the 
newly created quantized
   // graph. Key is the currently visited graph's node pointer, and value is a 
copied node of the key
@@ -174,24 +175,10 @@ Graph QuantizeGraph(Graph &&src) {
               }
             }
 
-            NodePtr quantize_node = InsertNode("_contrib_quantize",
+            NodePtr quantize_node = InsertNode("_contrib_quantize_v2",
               e.node->attrs.name + suffix + "_quantize", new_node, 
mirror_entry);
             quantize_node->attrs.dict["out_type"] = quantized_dtype;
             quantize_node->op()->attr_parser(&(quantize_node->attrs));
-            if (calib_quantize) {
-              NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + 
suffix + "_min");
-              quantize_node->inputs.emplace_back(NodeEntry{min_var, 0, 0});
-              NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + 
suffix + "_max");
-              quantize_node->inputs.emplace_back(NodeEntry{max_var, 0, 0});
-            } else {
-              NodePtr min_node = InsertNode("min",
-                  e.node->attrs.name + suffix + "_min", quantize_node, 
mirror_entry);
-              min_node->op()->attr_parser(&(min_node->attrs));
-
-              NodePtr max_node = InsertNode("max",
-                  e.node->attrs.name + suffix + "_max", quantize_node, 
mirror_entry);
-              max_node->op()->attr_parser(&(max_node->attrs));
-            }
             mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version};
           }
         } else if (mirror_node->op() == Op::Get("_contrib_dequantize")) {
@@ -269,43 +256,35 @@ Graph QuantizeGraph(Graph &&src) {
       // the new_node.
       *new_node = *node;
       new_node->inputs.clear();
-      if (node->is_variable() && node->attrs.name == "data") {
-        // Insert identity for data to collect calib for it.
-        NodePtr identity_node =
-            CreateNode("identity", new_node->attrs.name + "_id");
-        identity_node->inputs.emplace_back(NodeEntry{new_node, 0, 0});
-        new_node = identity_node;
-      } else {
-        for (const auto& e : node->inputs) {
-          NodePtr mirror_node = mirror_map.at(e.node.get());
-          NodeEntry mirror_entry = NodeEntry{
-            mirror_node, e.index, e.version};
-          // if input node is quantized operator, add dequantize node
-          if (NeedQuantize(e.node, excluded_nodes) &&
-              (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
-            // here we calculate the output number (exclude min/max, in order 
to
-            // calculate min/max index from mirror node) based on assumption 
that
-            // there is only 1min and 1max output from mirror node (which is
-            // currently true)
-            size_t num_outputs = mirror_node->num_outputs() - 2;
-            uint32_t min_index = num_outputs + 2 * e.index;
-            uint32_t max_index = num_outputs + 2 * e.index + 1;
-            NodePtr dequantize_node = CreateNode("_contrib_dequantize",
-              e.node->attrs.name + "_dequantize");
-            dequantize_node->inputs.emplace_back(mirror_entry);
-            dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, 
min_index, 0});
-            dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, 
max_index, 0});
-            dequantize_node->op()->attr_parser(&(dequantize_node->attrs));
+      for (const auto& e : node->inputs) {
+        NodePtr mirror_node = mirror_map.at(e.node.get());
+        NodeEntry mirror_entry = NodeEntry{
+          mirror_node, e.index, e.version};
+        // if input node is quantized operator, add dequantize node
+        if (NeedQuantize(e.node, excluded_nodes) &&
+            (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
+          // here we calculate the output number (exclude min/max, in order to
+          // calculate min/max index from mirror node) based on assumption that
+          // there is only 1min and 1max output from mirror node (which is
+          // currently true)
+          size_t num_outputs = mirror_node->num_outputs() - 2;
+          uint32_t min_index = num_outputs + 2 * e.index;
+          uint32_t max_index = num_outputs + 2 * e.index + 1;
+          NodePtr dequantize_node = CreateNode("_contrib_dequantize",
+            e.node->attrs.name + "_dequantize");
+          dequantize_node->inputs.emplace_back(mirror_entry);
+          dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, 
min_index, 0});
+          dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, 
max_index, 0});
+          dequantize_node->op()->attr_parser(&(dequantize_node->attrs));
 
-            new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
-            mirror_map[e.node.get()] = std::move(dequantize_node);
-          } else if (mirror_entry_map.count(e)) {
-            new_node->inputs.emplace_back(
-                NodeEntry{mirror_entry_map[e].node->inputs[0].node, e.index, 
e.version});
-          } else {
-            new_node->inputs.emplace_back(
-                NodeEntry{mirror_node, e.index, e.version});
-          }
+          new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
+          mirror_map[e.node.get()] = std::move(dequantize_node);
+        } else if (mirror_entry_map.count(e)) {
+          new_node->inputs.emplace_back(
+              NodeEntry{mirror_entry_map[e].node->inputs[0].node, e.index, 
e.version});
+        } else {
+          new_node->inputs.emplace_back(
+              NodeEntry{mirror_node, e.index, e.version});
         }
       }
     }
@@ -334,8 +313,7 @@ Graph QuantizeGraph(Graph &&src) {
     }
   }
 
-  if (!offline_params.empty()) outputs =
-    OfflineParams(std::move(outputs), std::move(offline_params));
+  if (!offline_params.empty()) outputs = OfflineParams(std::move(outputs), 
offline_params);
 
   Graph ret;
   ret.outputs = std::move(outputs);
@@ -361,7 +339,11 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
           && 
need_requantize_map[quantized_op_node->op()](quantized_op_node->attrs))
           << quantized_op_node->attrs.name << " op must register 
FNeedRequantize attr"
                                               " and the attr func should 
return true";
-      std::string out_data_name = quantized_op_node->attrs.name + "_";
+      const std::string prefix = "quantized_";
+      CHECK(std::equal(prefix.begin(), prefix.end(), 
quantized_op_node->attrs.name.begin()))
+          << "an quantized op should start with `quantized_`";
+
+      std::string out_data_name = 
quantized_op_node->attrs.name.substr(prefix.size()) + "_";
       auto list_output_names_func = flist_outputs.get(quantized_op_node->op(), 
nullptr);
       // Here it's assumed that the quantized_op node only produces three 
outputs:
       // out_data, min_range, and max_range. So we want to get the 
pre-calculated min_calib_range
@@ -381,6 +363,34 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
         node->attrs.dict["max_calib_range"] = 
std::to_string(calib_table_iter->second.second);
         node->op()->attr_parser(&(node->attrs));
       }
+    } else if (node->op() == Op::Get("_contrib_quantize_v2")) {
+      NodePtr float_op_node = node->inputs[0].node;
+      auto float_op_idx = node->inputs[0].index;
+      std::string out_data_name = float_op_node->attrs.name;
+      if (float_op_node->op()) {
+        auto list_output_names_func = flist_outputs.get(float_op_node->op(), 
nullptr);
+        // We want to get the pre-calculated min_range and max_range from the 
calibration table for
+        // out_data. Here we create the output data name same as its 
constructed in
+        // GraphExecutor::ExecuteMonCallback.
+        if (list_output_names_func != nullptr) {
+          std::vector<std::string> names = 
list_output_names_func(float_op_node->attrs);
+          out_data_name += "_" + names[float_op_idx];
+        } else {
+          out_data_name += "_" + std::to_string(float_op_idx);
+        }
+      }
+      const auto calib_table_iter = calib_table.find(out_data_name);
+      if (calib_table_iter != calib_table.end()) {
+        node->attrs.dict["min_calib_range"] = 
std::to_string(calib_table_iter->second.first);
+        node->attrs.dict["max_calib_range"] = 
std::to_string(calib_table_iter->second.second);
+        node->op()->attr_parser(&(node->attrs));
+        const QuantizeV2Param& param = 
nnvm::get<QuantizeV2Param>(node->attrs.parsed);
+        if (param.out_type == QuantizeV2Param::OutType::kUint8 &&
+            param.min_calib_range.value() < 0.0f) {
+          LOG(WARNING) << "Calibration statistics indicates that node `" << 
node->attrs.name
+                       << "` has negative input, consider use `auto` or `int8` 
as out_type";
+        }
+      }
     }
   });
   return g;
diff --git a/src/operator/quantization/quantize_v2-inl.h 
b/src/operator/quantization/quantize_v2-inl.h
new file mode 100644
index 00000000000..9ba2c3a5437
--- /dev/null
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file quantize_v2-inl.h
+ * \brief implementation of quantize operation
+ */
+#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <limits>
+#include "../elemwise_op_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "./quantization_utils.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Parameters of the quantize_v2 operator.
+// If both min_calib_range and max_calib_range are set, quantization uses that fixed range.
+// Otherwise QuantizeV2Compute derives the range from the input data at runtime.
+struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
+  // Requested output dtype. kAuto is resolved by GetOutputType: uint8 when a calibrated
+  // minimum is present and non-negative, int8 otherwise.
+  enum OutType { kAuto = 0, kInt8, kUint8 };
+  int out_type;
+  // Optional offline (calibrated) bounds of the fp32 input; used only when both are set.
+  dmlc::optional<float> min_calib_range;
+  dmlc::optional<float> max_calib_range;
+  DMLC_DECLARE_PARAMETER(QuantizeV2Param) {
+    DMLC_DECLARE_FIELD(out_type)
+      .add_enum("auto", kAuto)
+      .add_enum("int8", kInt8)
+      .add_enum("uint8", kUint8)
+      .set_default(kUint8)
+      .describe("Output data type. `auto` can be specified to automatically 
determine output type "
+                "according to min_calib_range.");
+    DMLC_DECLARE_FIELD(min_calib_range)
+      .set_default(dmlc::optional<float>())
+      .describe("The minimum scalar value in the form of float32. If present, 
it will be used to "
+                "quantize the fp32 data into int8 or uint8.");
+    DMLC_DECLARE_FIELD(max_calib_range)
+      .set_default(dmlc::optional<float>())
+      .describe("The maximum scalar value in the form of float32. If present, 
it will be used to "
+                "quantize the fp32 data into int8 or uint8.");
+  }
+};
+
+// Resolve QuantizeV2Param::out_type to a concrete mshadow dtype.
+// `auto` yields uint8 when both calibration bounds are present and the minimum is
+// non-negative. It yields int8 otherwise, including when no calibration range is available.
+// Declared `inline` because this header is included by translation units that may not call
+// this function (e.g. the quantization graph pass, which needs only QuantizeV2Param). A
+// non-inline `static` definition would trigger -Wunused-function there.
+static inline mshadow::TypeFlag GetOutputType(const QuantizeV2Param &param) {
+  auto out_type = mshadow::kInt8;
+  if (param.out_type == QuantizeV2Param::OutType::kAuto) {
+    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+      if (param.min_calib_range.value() >= 0.0) {
+        out_type = mshadow::kUint8;
+      } else {
+        out_type = mshadow::kInt8;
+      }
+    }
+  } else if (param.out_type == QuantizeV2Param::OutType::kInt8) {
+    out_type = mshadow::kInt8;
+  } else if (param.out_type == QuantizeV2Param::OutType::kUint8) {
+    out_type = mshadow::kUint8;
+  } else {
+    LOG(FATAL) << "Unsupported quantize output type.";
+  }
+  return out_type;
+}
+
+// quantize float to uint8_t
+// Affine map of [imin_range, imax_range] onto [min_limit, max_limit] (min_limit is 0 for the
+// uint8 launches below). Results are saturated to the limits: a calibrated range does not
+// bound runtime data, and an out-of-range float -> integer conversion is undefined behaviour.
+// The output range is reported as the input range.
+struct quantize_v2_unsigned {
+  // Range passed by value (e.g. host-side calibration bounds).
+  template <typename DstDType, typename SrcDType>
+  MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range,
+                                  const SrcDType *in, const float imin_range,
+                                  const float imax_range, const double min_limit,
+                                  const double max_limit) {
+    const float scale = (max_limit - min_limit) / (imax_range - imin_range);
+    double q = (in[i] - imin_range) * scale + 0.5;
+    if (q < min_limit) q = min_limit;
+    if (q > max_limit) q = max_limit;
+    out[i] = static_cast<DstDType>(q);
+    *omin_range = imin_range;
+    *omax_range = imax_range;
+  }
+
+  // Range passed through pointers into stream-accessible memory (e.g. a reduced min/max).
+  template <typename DstDType, typename SrcDType>
+  MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range,
+                                  const SrcDType *in, const float *imin_range,
+                                  const float *imax_range, const double min_limit,
+                                  const double max_limit) {
+    Map(i, out, omin_range, omax_range, in, *imin_range, *imax_range, min_limit, max_limit);
+  }
+};
+
+// keep zero-center
+// Symmetric map: scale by quantized_range / max(|imin_range|, |imax_range|), round half away
+// from zero, and saturate to +/- quantized_range. The output range is reported as
+// [-real_range, real_range].
+struct quantize_v2_zero_centered {
+  // Range passed by value (e.g. host-side calibration bounds).
+  template <typename DstDType, typename SrcDType>
+  MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range,
+                                  const SrcDType *in, const float imin_range,
+                                  const float imax_range, const float quantized_range) {
+    float real_range = MaxAbs(imin_range, imax_range);
+    float scale = quantized_range / real_range;
+    SrcDType x = in[i];
+    out[i] = static_cast<DstDType>(Sign(x) * Min(Abs(x) * scale + 0.5f, quantized_range));
+    *omin_range = -real_range;
+    *omax_range = real_range;
+  }
+
+  // Range passed through pointers into stream-accessible memory (e.g. a reduced min/max).
+  template <typename DstDType, typename SrcDType>
+  MSHADOW_XINLINE static void Map(int i, DstDType *out, float *omin_range, float *omax_range,
+                                  const SrcDType *in, const float *imin_range,
+                                  const float *imax_range, const float quantized_range) {
+    Map(i, out, omin_range, omax_range, in, *imin_range, *imax_range, quantized_range);
+  }
+};
+
+// Quantize the single fp32 input into outputs[0] (uint8 or int8, per GetOutputType) and write
+// the represented real range to outputs[1] (min) and outputs[2] (max).
+// The range comes from the calibration parameters when both are set. Otherwise it is reduced
+// from the input using ctx.requested[0] as temporary space.
+template <typename xpu>
+void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                       const std::vector<TBlob> &inputs, const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  typedef float SrcDType;
+  using mshadow::red::limits::MaxValue;
+  using mshadow::red::limits::MinValue;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+  auto out_type = GetOutputType(param);
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    // The calibrated range lives on the host, so it is passed to the kernel by value. The
+    // address of a host local must not be dereferenced inside a device kernel.
+    const float in_min = param.min_calib_range.value();
+    const float in_max = param.max_calib_range.value();
+    if (out_type == mshadow::kUint8) {
+      Kernel<quantize_v2_unsigned, xpu>::Launch(
+          s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
+          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min, in_max,
+          MinValue<uint8_t>(), MaxValue<uint8_t>());
+    } else if (out_type == mshadow::kInt8) {  // zero-centered quantization
+      Kernel<quantize_v2_zero_centered, xpu>::Launch(
+          s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
+          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min, in_max,
+          MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+    } else {
+      LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+    }
+  } else {  // model is not calibrated
+    // Reduce the fp32 input straight into two fp32 scalars in the temp space. The op has a
+    // single input (QuantizeV2Shape/QuantizeV2Type check this), so there is no incoming
+    // min/max range to read, and the reduced values are already real-valued.
+    // Temp space layout: [min (SrcDType)][max (SrcDType)][reduce workspace].
+    TShape src_shape, dst_shape;
+    const size_t range_bytes = 2 * sizeof(SrcDType);
+    const size_t temp_reduce_size =
+        ConfigReduce<xpu, SrcDType>(s, inputs[0].shape_, TShape({1}), &src_shape, &dst_shape);
+    Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
+        Shape1(range_bytes + temp_reduce_size), s);
+    const int dev_id = ctx.run_ctx.ctx.dev_id;
+    SrcDType *range_ptr = reinterpret_cast<SrcDType *>(temp_space.dptr_);
+    TBlob in_min_t(range_ptr, Shape1(1), xpu::kDevMask, dev_id);
+    TBlob in_max_t(range_ptr + 1, Shape1(1), xpu::kDevMask, dev_id);
+    Tensor<xpu, 1, char> workspace(temp_space.dptr_ + range_bytes, Shape1(temp_reduce_size), s);
+    broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
+        s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+    broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
+        s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+    if (out_type == mshadow::kUint8) {
+      Kernel<quantize_v2_unsigned, xpu>::Launch(
+          s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
+          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
+          in_max_t.dptr<float>(), MinValue<uint8_t>(), MaxValue<uint8_t>());
+    } else if (out_type == mshadow::kInt8) {  // zero-centered quantization
+      Kernel<quantize_v2_zero_centered, xpu>::Launch(
+          s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
+          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
+          in_max_t.dptr<float>(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+    } else {
+      LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+    }
+  }
+}
+
+// Shape inference: one data input and three outputs.
+// The quantized data has the input's shape; the min and max ranges are 1-element tensors.
+// Inference is complete once the data output's shape is known.
+static inline bool QuantizeV2Shape(const nnvm::NodeAttrs &attrs, std::vector<TShape> *in_attrs,
+                                   std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 3U);
+
+  const TShape range_shape{1};
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, range_shape);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, range_shape);
+  const TShape &data_out = (*out_attrs)[0];
+  return !shape_is_none(data_out);
+}
+
+// Type inference: the input is fp32.
+// The quantized output takes the dtype chosen by GetOutputType, and the min/max range
+// outputs are fp32. Returns whether the input dtype is known.
+static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 3U);
+  const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  switch (GetOutputType(param)) {
+    case mshadow::kUint8:
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
+      break;
+    case mshadow::kInt8:
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported out_type.";
+      break;
+  }
+  for (int idx = 1; idx <= 2; ++idx) {
+    TYPE_ASSIGN_CHECK(*out_attrs, idx, mshadow::kFloat32);
+  }
+  return (*in_attrs)[0] != -1;
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_
diff --git a/src/operator/quantization/quantize_v2.cc 
b/src/operator/quantization/quantize_v2.cc
new file mode 100644
index 00000000000..afa341f8a78
--- /dev/null
+++ b/src/operator/quantization/quantize_v2.cc
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file quantize_v2.cc
+ * \brief CPU registration of the _contrib_quantize_v2 operator.
+ */
+#include "./quantize_v2-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_quantize_v2-inl.h"
+#endif
+
+namespace mxnet {
+namespace op {
+DMLC_REGISTER_PARAMETER(QuantizeV2Param);
+
+/*!
+ * \brief Storage-type inference for _contrib_quantize_v2.
+ *
+ * All three outputs (quantized data, min_range, max_range) use dense default
+ * storage. Dispatch goes to FCompute, except that in MKL-DNN builds a CPU
+ * device selects FComputeEx instead.
+ * \return always true.
+ */
+static bool QuantizeV2StorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  // Same arity checks as QuantizeV2Shape / QuantizeV2Type, so a malformed node
+  // fails loudly instead of indexing past the end of out_attrs below.
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 3U);
+  *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  for (int &stype : *out_attrs) {
+    stype = kDefaultStorage;
+  }
+  return true;
+}
+
+// Registers the operator together with its inference functions and CPU kernels.
+// The GPU FCompute is attached separately in quantize_v2.cu.
+NNVM_REGISTER_OP(_contrib_quantize_v2)
+.describe(R"code(Quantize an input tensor from float to `out_type`,
+with user-specified `min_calib_range` and `max_calib_range` or the input range collected at runtime.
+
+Output `min_range` and `max_range` are scalar floats that specify the range for the input data.
+
+When out_type is `uint8`, the output is calculated using the following equation:
+
+`out[i] = (in[i] - min_range) * range(OUTPUT_TYPE) / (max_range - min_range) + 0.5`,
+
+where `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`.
+
+When out_type is `int8`, the output is calculated using the following equation,
+which keeps the quantized values centered at zero:
+
+`out[i] = sign(in[i]) * min(abs(in[i] * scale) + 0.5f, quantized_range)`,
+
+where
+`quantized_range = MinAbs(max(int8), min(int8))` and
+`scale = quantized_range / MaxAbs(min_range, max_range).`
+
+When out_type is `auto`, the output type is automatically determined by min_calib_range if present.
+If min_calib_range < 0.0f, the output type will be int8, otherwise it will be uint8.
+If min_calib_range isn't present, the output type will be int8.
+
+.. Note::
+    This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<QuantizeV2Param>)
+.set_num_inputs(1)
+.set_num_outputs(3)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", QuantizeV2Shape)
+.set_attr<nnvm::FInferType>("FInferType", QuantizeV2Type)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizeV2StorageType)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeV2Compute)
+#endif
+.set_attr<FCompute>("FCompute<cpu>", QuantizeV2Compute<cpu>)
+.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`")
+.add_arguments(QuantizeV2Param::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/quantization/quantize_v2.cu 
b/src/operator/quantization/quantize_v2.cu
new file mode 100644
index 00000000000..ab0cf9c5ad0
--- /dev/null
+++ b/src/operator/quantization/quantize_v2.cu
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file quantize_v2.cu
+ * \brief GPU registration of the _contrib_quantize_v2 operator.
+ */
+#include "./quantize_v2-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_contrib_quantize_v2)  // op and CPU kernels are registered in quantize_v2.cc
+.set_attr<FCompute>("FCompute<gpu>", QuantizeV2Compute<gpu>);  // attach the GPU kernel
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/quantization/requantize-inl.h 
b/src/operator/quantization/requantize-inl.h
index e07a149f8a6..148453e6325 100644
--- a/src/operator/quantization/requantize-inl.h
+++ b/src/operator/quantization/requantize-inl.h
@@ -87,20 +87,6 @@ struct RequantizeKernel {
   }
 };
 
-template<typename xpu, typename DType>
-inline size_t ConfigReduce(mshadow::Stream<xpu>* s,
-                           const TShape& data_shape,
-                           const TShape& out_shape,
-                           TShape* src_shape,
-                           TShape* dst_shape) {
-  BroadcastReduceShapeCompact(data_shape, out_shape, src_shape, dst_shape);
-  constexpr int NDim = 2;
-  CHECK_EQ(src_shape->ndim(), NDim);
-  CHECK_EQ(dst_shape->ndim(), NDim);
-
-  return broadcast::ReduceWorkspaceSize<NDim, DType>(s, *dst_shape, kWriteTo, 
*src_shape);
-}
-
 template<typename xpu>
 void RequantizeForward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc 
b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index dfa98d1f5ee..2099d1b1ec2 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -43,10 +43,10 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray 
*bias, bool no_bias,
                                 true, beta.dtype());
   const DType *weight_ptr = weight->data().dptr<DType>();
   const DType *bias_ptr = no_bias ? nullptr : bias->data().dptr<DType>();
-  const DType *gamma_ptr = gamma.Reorder2Default().data().dptr<DType>();
-  const DType *beta_ptr = beta.Reorder2Default().data().dptr<DType>();
-  const DType *mean_ptr = mean.Reorder2Default().data().dptr<DType>();
-  const DType *var_ptr = variance.Reorder2Default().data().dptr<DType>();
+  const DType *gamma_ptr = gamma.data().dptr<DType>();
+  const DType *beta_ptr = beta.data().dptr<DType>();
+  const DType *mean_ptr = mean.data().dptr<DType>();
+  const DType *var_ptr = variance.data().dptr<DType>();
   DType *update_weight_ptr = update_weight.data().dptr<DType>();
   DType *update_bias_ptr = update_bias.data().dptr<DType>();
   size_t channel = gamma.shape()[0];
@@ -77,23 +77,17 @@ static inline size_t GetInSumIndex(const 
MKLDNNConvFusionParam &param) {
 }
 
 template <typename DType>
-static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias,
-                                   bool has_bias, float data_scale,
-                                   bool weight_channelwise_scale,
-                                   std::vector<float> *weight_scales) {
+static std::vector<float> GetWeightScales(const NDArray &weight, bool 
weight_channelwise_scale) {
   using red::limits::MaxValue;
   using red::limits::MinValue;
-  const DType *weight_ptr = weight->data().dptr<DType>();
-  NDArray quantized_weight = NDArray(weight->storage_type(), weight->shape(),
-                                     weight->ctx(), true, mshadow::kInt8);
-  int8_t *quan_weight_ptr = quantized_weight.data().dptr<int8_t>();
-  size_t channel = weight->shape()[0];
+  std::vector<float> weight_scales;
+  const DType *weight_ptr = weight.data().dptr<DType>();
+  size_t channel = weight.shape()[0];
 
   // TODO(Zhennan): Handle the case weight is not in dims 4.
-  size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3];
+  size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3];
   std::vector<DType> weight_c_min(channel, MaxValue<DType>());
   std::vector<DType> weight_c_max(channel, MinValue<DType>());
-#pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
   for (int c = 0; c < static_cast<int>(channel); ++c) {
     const DType *p1 = weight_ptr + c * offset;
     for (size_t k = 0; k < offset; ++k) {
@@ -105,16 +99,10 @@ static void QuantizeConvWeightBias(NDArray *weight, 
NDArray *bias,
   }
 
   if (weight_channelwise_scale) {
-    weight_scales->resize(channel);
-#pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+    weight_scales.resize(channel);
     for (int c = 0; c < static_cast<int>(channel); ++c) {
       DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]);
-      weight_scales->at(c) = kInt8Range / weight_range;
-      const DType *fp_ptr = weight_ptr + c * offset;
-      int8_t *quan_ptr = quan_weight_ptr + c * offset;
-      for (size_t k = 0; k < offset; ++k) {
-        quan_ptr[k] = std::round(weight_scales->at(c) * fp_ptr[k]);
-      }
+      weight_scales[c] = kInt8Range / weight_range;
     }
   } else {
     DType total_min = weight_c_min[0];
@@ -123,74 +111,73 @@ static void QuantizeConvWeightBias(NDArray *weight, 
NDArray *bias,
       if (total_min > weight_c_min[c]) total_min = weight_c_min[c];
       if (total_max < weight_c_max[c]) total_max = weight_c_max[c];
     }
-    weight_scales->resize(1);
+    weight_scales.resize(1);
     DType weight_range = MaxAbs(total_min, total_max);
-    weight_scales->at(0) = kInt8Range / weight_range;
-#pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-    for (int c = 0; c < static_cast<int>(channel); ++c) {
-      const DType *fp_ptr = weight_ptr + c * offset;
-      int8_t *quan_ptr = quan_weight_ptr + c * offset;
-      for (size_t k = 0; k < offset; ++k) {
-        quan_ptr[k] = std::round(weight_scales->at(0) * fp_ptr[k]);
-      }
-    }
-  }
-
-  *weight = quantized_weight;
-  if (has_bias) {
-    const DType *bias_ptr = bias->data().dptr<DType>();
-    NDArray quantized_bias = NDArray(bias->storage_type(), bias->shape(),
-                                     bias->ctx(), true, mshadow::kInt32);
-    int32_t *quan_bias_ptr = quantized_bias.data().dptr<int32_t>();
-    for (size_t c = 0; c < channel; ++c) {
-      auto weight_scale =
-          weight_channelwise_scale ? weight_scales->at(c) : 
weight_scales->at(0);
-      float bias_scale = weight_scale * data_scale;
-      quan_bias_ptr[c] = std::round(bias_scale * bias_ptr[c]);
-    }
-    *bias = quantized_bias;
+    weight_scales[0] = kInt8Range / weight_range;
   }
+  return weight_scales;
 }
 
-static void ConvFusionFallBackCompute() {
-  LOG(FATAL) << "Don't know how to do ConvFusionFallBackCompute!";
-}
-
-static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam 
&full_param,
-                                          const OpContext &ctx,
-                                          MKLDNNConvForward *fwd,
-                                          const std::vector<NDArray> &inputs,
-                                          const std::vector<OpReqType> &req,
-                                          const std::vector<NDArray> &outputs) 
{
-  if (SupportMKLDNNConv(full_param.conv_param, inputs[0])) {
-    MKLDNNConvolutionForwardFullFeature(full_param, ctx, fwd, inputs, req, 
outputs);
-    return;
+static void ConvertWeightBias2MKLDNN(const MKLDNNConvFullParam &param,
+                                     
mkldnn::convolution_forward::primitive_desc fwd_pd,
+                                     NDArray *weight, NDArray *bias, bool 
has_bias,
+                                     float data_scale, const 
std::vector<float> &weight_scales) {
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  const auto new_weight = NDArray(fwd_pd.weights_primitive_desc());
+  const auto conv_weights_memory = new_weight.GetMKLDNNData();
+  primitive_attr weight_attr;
+  if (weight_scales.size()) {
+    const int weight_mask = (weight_scales.size()) == 1 ? 0 : 1;
+    weight_attr.set_int_output_round_mode(round_mode::round_nearest);
+    weight_attr.set_output_scales(weight_mask, weight_scales);
+  }
+  auto default_weights_memory = GetWeights(*weight, 
param.conv_param.num_group);
+  if (default_weights_memory == nullptr) default_weights_memory = 
weight->GetMKLDNNData();
+  const auto weight_reorder_pd =
+      
mkldnn::reorder::primitive_desc(default_weights_memory->get_primitive_desc(),
+                                      
conv_weights_memory->get_primitive_desc(), weight_attr);
+  stream->RegisterPrim(
+      mkldnn::reorder(weight_reorder_pd, *default_weights_memory, 
*conv_weights_memory));
+
+  NDArray new_bias;
+  if (has_bias && data_scale) {
+    std::vector<float> bias_scales(weight_scales.size());
+    for (size_t c = 0; c < weight_scales.size(); ++c) {
+      bias_scales[c] = weight_scales[c] * data_scale;
+    }
+    new_bias = NDArray(fwd_pd.bias_primitive_desc());
+    const auto conv_bias_memory = new_bias.GetMKLDNNData();
+    const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1;
+    primitive_attr bias_attr;
+    bias_attr.set_int_output_round_mode(round_mode::round_nearest);
+    bias_attr.set_output_scales(bias_mask, bias_scales);
+    auto bias_weights_memory = bias->GetMKLDNNData();
+    auto bias_reorder_pd =
+        
mkldnn::reorder::primitive_desc(bias_weights_memory->get_primitive_desc(),
+                                        
conv_bias_memory->get_primitive_desc(), bias_attr);
+    stream->RegisterPrim(
+        mkldnn::reorder(bias_reorder_pd, *bias_weights_memory, 
*conv_bias_memory));
   }
-  ConvFusionFallBackCompute();
+  stream->Submit();
+  *weight = new_weight;
+  if (has_bias && data_scale) *bias = new_bias;
 }
 
 class SgMKLDNNConvOperator {
  public:
   explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs)
-      : initalized_(false),
-        subgraph_sym_(*attrs.subgraphs[0]),
-        param_(nnvm::get<MKLDNNConvFusionParam>(attrs.parsed)),
-        inplace_(false) {}
+      : subgraph_sym_(*attrs.subgraphs[0]),
+        param_(nnvm::get<MKLDNNConvFusionParam>(attrs.parsed)) {}
 
   void Forward(const OpContext &ctx,
                const std::vector<NDArray> &inputs,
                const std::vector<OpReqType> &req,
                const std::vector<NDArray> &outputs);
 
-  void Backward(const OpContext &ctx, const std::vector<NDArray> &inputs,
-                const std::vector<OpReqType> &req,
-                const std::vector<NDArray> &outputs) {
-    LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports "
-                  "inference computation.";
-  }
-
  private:
-  bool initalized_;
+  bool initalized_{false};
+  bool inplace_{false};
+  bool post_requantize_{false};
   nnvm::Symbol subgraph_sym_;
   MKLDNNConvFusionParam param_;
   std::shared_ptr<MKLDNNConvForward> fwd_;
@@ -200,10 +187,12 @@ class SgMKLDNNConvOperator {
   float cached_data_max_;
   float cached_sum_min_;
   float cached_sum_max_;
+  float cached_output_min_;
+  float cached_output_max_;
   size_t weight_ver_;
   size_t bias_ver_;
+  float data_scale_{0.0f};
   std::vector<float> weight_scales_;
-  bool inplace_;
 };
 
 void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
@@ -239,10 +228,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
   float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized)
                       ? inputs[idx++].data().dptr<float>()[0]
                       : 0.0;
-  float *out_min_ptr =
-      mkldnn_param.quantized ? outputs[kMin].data().dptr<float>() : nullptr;
-  float *out_max_ptr =
-      mkldnn_param.quantized ? outputs[kMax].data().dptr<float>() : nullptr;
   CHECK_EQ(input_size, idx);
   bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias;
   NDArray data = inputs[in_data];
@@ -251,18 +236,22 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
   // Copy inputs[in_sum] into outputs[kOut] in case inplace optimization 
failed.
   if (mkldnn_param.with_sum) {
     if (!initalized_) {
-      auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
-      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
       // TODO(zhennan): Currently, mkldnn fallback mechanism will break 
inplace option,
       // which make check (req[kOut] == kWriteInplace) useless.
+      auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
+      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
       if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
         inplace_ = true;
       }
     }
     if (!inplace_) {
       auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
-      const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
-      output = NDArray(outputs[kOut].GetMKLDNNData());
+      auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+      mkldnn_mem_ptr tmp_mem(
+          new mkldnn::memory(in_mkl_mem->get_primitive_desc(), 
out_mkl_mem->get_data_handle()));
+      MKLDNNStream::Get()->RegisterMem(tmp_mem);
+      mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
+      output = NDArray(tmp_mem);
     }
   }
 
@@ -284,19 +273,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
       }
     }
   }
-  bool post_requantize = false;
-  if (mkldnn_param.quantized) {
-    if (mkldnn_param.min_calib_range.has_value() &&
-        mkldnn_param.max_calib_range.has_value()) {
-      post_requantize = true;
-      mkldnn_param.weight_channelwise_scale = true;
-      *out_min_ptr = mkldnn_param.min_calib_range.value();
-      *out_max_ptr = mkldnn_param.max_calib_range.value();
-    } else {
-      mkldnn_param.weight_channelwise_scale = false;
-    }
-  }
-
   if (!initalized_) {
     cached_data_min_ = data_min;
     cached_data_max_ = data_max;
@@ -306,7 +282,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     cached_weight_ = inputs[in_weight].Reorder2Default();
     weight_ver_ = inputs[in_weight].version();
     if (!conv_param.no_bias) {
-      cached_bias_ = inputs[in_bias].Reorder2Default();
+      cached_bias_ = inputs[in_bias];
       bias_ver_ = inputs[in_bias].version();
     } else {
       cached_bias_ = NDArray();
@@ -327,13 +303,23 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     // Quantize weight and bias.
     if (mkldnn_param.quantized) {
       CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
+      if (cached_data_min_ < 0.0f) {
+        CHECK_EQ(data.dtype(), mshadow::kInt8)
+            << "Expect int8 when data_min < 0.0, consider quantize model with 
int8.";
+      }
+      if (mkldnn_param.min_calib_range.has_value() && 
mkldnn_param.max_calib_range.has_value()) {
+        cached_output_min_ = mkldnn_param.min_calib_range.value();
+        cached_output_max_ = mkldnn_param.max_calib_range.value();
+        post_requantize_ = true;
+        mkldnn_param.weight_channelwise_scale = true;
+      } else {
+        mkldnn_param.weight_channelwise_scale = false;
+      }
       auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : 
kUint8Range;
-      float data_scale = data_range / MaxAbs(cached_data_min_, 
cached_data_max_);
+      data_scale_ = data_range / MaxAbs(cached_data_min_, cached_data_max_);
       MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
-        QuantizeConvWeightBias<DType>(&cached_weight_, &cached_bias_,
-                                      has_bias, data_scale,
-                                      mkldnn_param.weight_channelwise_scale,
-                                      &weight_scales_);
+        weight_scales_ =
+            GetWeightScales<DType>(cached_weight_, 
mkldnn_param.weight_channelwise_scale);
       });
       // Collect scale.
       size_t channel = cached_weight_.shape()[0];
@@ -341,29 +327,21 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
       float out_range;
       float quantized_out_range;
       float output_scale;
-      if (cached_data_min_ < 0.0) {
-        // TODO(zhennan): Support int8 input when mkldnn supports.
-        LOG(FATAL) << "Can't handle negetive value for QuantizeData";
-      }
       if (mkldnn_param.with_sum) {
         auto quantized_sum_range = cached_sum_min_ < 0 ? kInt8Range : 
kUint8Range;
         sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, 
cached_sum_max_);
       }
-      if (post_requantize) {
-        quantized_out_range =
-            IsOutputUInt8(mkldnn_param) ? kUint8Range : kInt8Range;
-        out_range = MaxAbs(*out_min_ptr, *out_max_ptr);
+      if (post_requantize_) {
+        quantized_out_range = IsOutputUInt8(mkldnn_param) ? kUint8Range : 
kInt8Range;
+        out_range = MaxAbs(cached_output_min_, cached_output_max_);
         output_scale = quantized_out_range / out_range;
-        full_conv_param.requantize_scales.resize(channel);
-        for (size_t c = 0; c < channel; c++) {
-          auto weight_scale = mkldnn_param.weight_channelwise_scale
-                                  ? weight_scales_[c]
-                                  : weight_scales_[0];
-          full_conv_param.requantize_scales[c] =
-              output_scale / data_scale / weight_scale;
+        
full_conv_param.requantize_scales.resize(mkldnn_param.weight_channelwise_scale 
? channel
+                                                                               
        : 1);
+        for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
+          full_conv_param.requantize_scales[c] = output_scale / data_scale_ / 
weight_scales_[c];
         }
       } else {
-        output_scale = data_scale * weight_scales_[0];
+        output_scale = data_scale_ * weight_scales_[0];
         full_conv_param.requantize_scales.resize(0);
       }
       if (mkldnn_param.with_sum)
@@ -372,23 +350,44 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     fwd_.reset(new MKLDNNConvForward(
         full_conv_param, ctx.is_train, data, cached_weight_,
         has_bias ? &cached_bias_ : nullptr, output));
+    ConvertWeightBias2MKLDNN(full_conv_param, fwd_->fwd_pd, &cached_weight_, 
&cached_bias_,
+                             has_bias, data_scale_, weight_scales_);
+    fwd_->SetNewMem(*data.GetMKLDNNData(), *cached_weight_.GetMKLDNNData(),
+                    has_bias ? cached_bias_.GetMKLDNNData() : nullptr,
+                    *output.GetMKLDNNData());
+    initalized_ = true;
   }
-  initalized_ = true;
-  std::vector<NDArray> new_inputs;
-  std::vector<OpReqType> new_req;
-  if (has_bias) {
-    new_inputs = {data, cached_weight_, cached_bias_};
-    new_req = {req[in_data], req[in_weight], req[in_bias]};
+
+  if (!mkldnn_param.quantized) {
+    auto data_mem = 
data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc());
+    mkldnn::memory *mem = 
output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc());
+    fwd_->SetNewMem(*data_mem, *mem);
+    MKLDNNStream::Get()->RegisterPrim(fwd_->GetFwd());
+    MKLDNNStream::Get()->Submit();
   } else {
-    new_inputs = {data, cached_weight_};
-    new_req = {req[in_data], req[in_weight]};
+    std::vector<NDArray> new_inputs;
+    std::vector<OpReqType> new_req;
+    if (has_bias) {
+      new_inputs = {data, cached_weight_, cached_bias_};
+      new_req = {req[in_data], req[in_weight], req[in_bias]};
+    } else {
+      new_inputs = {data, cached_weight_};
+      new_req = {req[in_data], req[in_weight]};
+    }
+    MKLDNNConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), 
new_inputs, new_req,
+                                        {output});
+  }
+  if (post_requantize_) {
+  float *out_min_ptr = outputs[kMin].data().dptr<float>();
+  float *out_max_ptr = outputs[kMax].data().dptr<float>();
+  *out_min_ptr = cached_output_min_;
+  *out_max_ptr = cached_output_max_;
   }
-  ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs,
-                                new_req, {output});
-
   if (mkldnn_param.with_sum) {
     auto out = const_cast<NDArray &>(outputs[kOut]);
-    out.UpdateMKLDNNMemDesc();
+    auto format = static_cast<mkldnn::memory::format>(
+        fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
+    out.UpdateMKLDNNMemDesc(format);
   }
 }
 
@@ -405,7 +404,7 @@ static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs 
&attrs) {
   auto const &param = nnvm::get<MKLDNNConvFusionParam>(attrs.parsed);
   auto num_input = DefaultSubgraphOpNumInputs(attrs);
   if (param.full_conv_param.mkldnn_param.quantized)
-    return num_input + 2 + param.full_conv_param.mkldnn_param.with_sum ? 2 : 0;
+    return num_input + 2 + (param.full_conv_param.mkldnn_param.with_sum ? 2 : 
0);
   else
     return num_input;
 }
@@ -425,6 +424,7 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) 
{
     os << ")";
     throw dmlc::ParamError(os.str());
   }
+  CHECK_EQ(attrs->subgraphs.size(), 1);
   auto subgraph_sym = attrs->subgraphs[0];
   DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) {
     if (node->is_variable()) return;
@@ -442,10 +442,23 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs 
*attrs) {
   attrs->parsed = std::move(param_);
 }
 
-static std::vector<std::string> SgMKLDNNConvListInputNames(
-    const NodeAttrs &attrs) {
+static std::vector<std::string> SgMKLDNNConvListInputNames(const NodeAttrs 
&attrs) {
   auto const &param = nnvm::get<MKLDNNConvFusionParam>(attrs.parsed);
-  std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
+  std::vector<std::string> input_names;
+  input_names.emplace_back("data");
+  input_names.emplace_back("weight");
+  if (!param.full_conv_param.conv_param.no_bias) {
+    input_names.emplace_back("bias");
+  }
+  if (param.full_conv_param.mkldnn_param.with_bn) {
+    input_names.emplace_back("gamma");
+    input_names.emplace_back("beta");
+    input_names.emplace_back("mean");
+    input_names.emplace_back("var");
+  }
+  if (param.full_conv_param.mkldnn_param.with_sum) {
+    input_names.emplace_back("sum");
+  }
   if (param.full_conv_param.mkldnn_param.quantized) {
     input_names.emplace_back("data_min");
     input_names.emplace_back("data_max");
@@ -454,6 +467,7 @@ static std::vector<std::string> SgMKLDNNConvListInputNames(
       input_names.emplace_back("sum_max");
     }
   }
+  CHECK_EQ(input_names.size(), SgMKLDNNConvNumInputs(attrs));
   return input_names;
 }
 
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc 
b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
index e5220f24d34..adfc41bb120 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
@@ -66,17 +66,21 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
   }
 
   bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
-    if (status == kFail || status == kSuccess || new_node.is_variable())
-      return false;
     // If n isn't the last matched node, then we encoutered a internal
     // branch, we should pop out the node behind n and stop fusion.
     if (matched_list.back() != &n) {
-      while (matched_list.back() != &n) {
-        matched_list.pop_back();
+      if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+          matched_list.end()) {
+        while (matched_list.back() != &n) {
+          matched_list.pop_back();
+        }
       }
       status = kSuccess;
       return false;
     }
+    if (status == kFail || status == kSuccess || new_node.is_variable())
+      return false;
+
     // Use status machine to do selection. The status change is
     // kStart -> kBN -> kSum -> kSuccess
     switch (status) {
@@ -99,12 +103,11 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
               nnvm::get<ActivationParam>(new_node.attrs.parsed);
           if (param.act_type == activation::kReLU) {
             matched_list.push_back(&new_node);
-            // If we find conv+relu, then we can't match bn anymore.
-            if (status == kStart) status = kBN;
-            return true;
-          } else {
+            // If we find conv+relu, then we can't match anymore.
+            // TODO(zhennan): mkldnn only supports convolution + relu + sum in
+            // int8, not in fp32. So we disable this pattern at moment.
             status = kSuccess;
-            return false;
+            return true;
           }
         }
         status = kSuccess;
@@ -117,7 +120,15 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
     if (status == kFail) {
       return std::vector<nnvm::Node *>(0);
     } else {
-      return candidates;
+      std::vector<nnvm::Node *> ret;
+      for (auto i : matched_list) {
+        auto non_const_i = const_cast<nnvm::Node *>(i);
+        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+            candidates.end()) {
+          ret.push_back(non_const_i);
+        }
+      }
+      return ret;
     }
   }
 };
diff --git a/tests/python/mkl/test_subgraph.py 
b/tests/python/mkl/test_subgraph.py
index be6feaeb94a..313668cb56f 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -35,14 +35,14 @@
 
 DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)]
 
-def check_qsym_calibrated(qsym):
+def check_qsym_calibrated(qsym, out_type):
   assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != 
-1
   for k, v in qsym.attr_dict().items():
     if k.find('quantized_sg_mkldnn_conv') != -1:
       assert 'min_calib_range' in v
       assert 'max_calib_range' in v
     if k.find('_quantize') != -1:
-      assert v['out_type'] == 'uint8'
+      assert v['out_type'] == out_type
 
 def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, 
label_shape):
   mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
@@ -66,7 +66,7 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, 
label_shape):
     output.wait_to_read()
   return mod.get_outputs()
 
-def check_quantize(sym, data_shape, check_conv=True):
+def check_quantize(sym, data_shape, out_type, check_conv=True):
   fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
   sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
   sym_sg = sym.get_backend_symbol("MKLDNN")
@@ -99,15 +99,14 @@ def check_quantize(sym, data_shape, check_conv=True):
                                                                    
aux_params=aux_params,
                                                                    
ctx=mx.current_context(),
                                                                    
excluded_sym_names=excluded_sym_names,
-                                                                   
quantized_dtype='uint8',
+                                                                   
quantized_dtype=out_type,
                                                                    
calib_mode='naive',
                                                                    
calib_data=calib_data,
                                                                    
calib_layer=calib_layer,
-                                                                   
calib_quantize_op=True,
                                                                    
num_calib_examples=5)
   qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
   if check_conv:
-    check_qsym_calibrated(qsym)
+    check_qsym_calibrated(qsym, out_type)
   quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, 
data_shape, label_shape)
   for i in range(len(ref_out)):
     assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol 
= 1)
@@ -135,8 +134,9 @@ def check_fusion(sym, data_shape, attrs_op):
   for i in range(len(exe.outputs)):
     assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), 
rtol=1e-3, atol=1e-3)
 
-  # fp32 to uint8
-  check_quantize(sym, data_shape)
+  # fp32 to int8
+  for out_type in ('uint8', 'int8', 'auto'):
+    check_quantize(sym, data_shape, out_type)
 
 def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, 
date_shape=(4,4,10,10)):
   for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs):
@@ -475,12 +475,13 @@ def test_pos_conv_bn_sum_relu():
 
 def test_pos_single_concat():
   for data_shape in DATA_SHAPE:
-    net = single_concat(data_shape, 2, 1)
-    check_quantize(net, data_shape, False)
-    net = single_concat(data_shape, 4, 2)
-    check_quantize(net, data_shape, False)
-    net = single_concat(data_shape, 4, 3)
-    check_quantize(net, data_shape, False)
+    for out_type in ('uint8', 'int8', 'auto'):
+      net = single_concat(data_shape, 2, 1)
+      check_quantize(net, data_shape, out_type, False)
+      net = single_concat(data_shape, 4, 2)
+      check_quantize(net, data_shape, out_type, False)
+      net = single_concat(data_shape, 4, 3)
+      check_quantize(net, data_shape, out_type, False)
 
 @with_seed()
 def test_neg_conv_bn():
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index 09157396f83..b25c726cbcc 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6731,7 +6731,7 @@ def get_output_names_callback(name, arr):
             output_names.append(py_str(name))
 
         op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
-        op_exe.set_monitor_callback(get_output_names_callback)
+        op_exe.set_monitor_callback(get_output_names_callback, 
monitor_all=False)
         op_exe.forward()
         for output_name, expected_name in zip(output_names, expected_names):
             assert output_name == expected_name
@@ -6769,6 +6769,51 @@ def get_output_names_callback(name, arr):
                             name='pooling')
     check_name(us_sym, ['pooling_output'])
 
+def test_op_all_names_monitor():
+    def check_name(op_sym, expected_names):
+        output_names = []
+
+        def get_output_names_callback(name, arr):
+            output_names.append(py_str(name))
+
+        op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+        op_exe.set_monitor_callback(get_output_names_callback, 
monitor_all=True)
+        op_exe.forward()
+        for output_name, expected_name in zip(output_names, expected_names):
+            assert output_name == expected_name
+
+    data = mx.sym.Variable('data', shape=(10, 3, 10, 10))
+    conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, 
name='conv')
+    check_name(conv_sym, ['data', 'conv_data', 'conv_weight', 'conv_weight', 
'conv_bias', 'conv_bias', 'conv_output'])
+
+    deconv_sym = mx.sym.Deconvolution(data, kernel=(2, 2), num_filter=1, 
name='deconv')
+    check_name(deconv_sym, ['data', 'deconv_data', 'deconv_weight', 
'deconv_weight', 'deconv_output'])
+
+    fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
+    check_name(fc_sym, ['data', 'fc_data', 'fc_weight', 'fc_weight', 
'fc_bias', 'fc_bias', 'fc_output'])
+
+    lrn_sym = mx.sym.LRN(data, nsize=1, name='lrn')
+    check_name(lrn_sym, ['data', 'lrn_data', 'lrn_output', 'lrn_tmp_norm'])
+
+    act_sym = mx.sym.Activation(data, act_type='relu', name='act')
+    check_name(act_sym, ['data', 'act_input0', 'act_output'])
+
+    cc_sym = mx.sym.concat(data, data, dim=0, name='concat')
+    check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 
'concat_output'])
+
+    sm_sym = mx.sym.softmax(data, name='softmax')
+    check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output'])
+
+    sa_sym = mx.sym.SoftmaxActivation(data, name='softmax')
+    check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])
+
+    us_sym = mx.sym.UpSampling(data, scale=2, sample_type='nearest',
+                               name='upsampling')
+    check_name(us_sym, ['data', 'upsampling_arg0', 'upsampling_output'])
+
+    us_sym = mx.sym.Pooling(data, kernel=(2, 2), pool_type='avg',
+                            name='pooling')
+    check_name(us_sym, ['data', 'pooling_data', 'pooling_output'])
 
 @with_seed()
 def test_activation():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to