huangzhiyuan closed pull request #12956: Add reshape op supported by MKL-DNN
URL: https://github.com/apache/incubator-mxnet/pull/12956
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index a60d6555c74..728bc03669f 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -105,18 +105,14 @@ void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
     stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
   } else if (!same_shape(this_desc, from_desc)) {
     // In this case, the source memory stores data in a customized layout. We
-    // need to reorganize the data in memory before we can reshape.
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
-    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
-    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
-    // Now we can reshape it
+    // need to reorganize the data in memory and reshape it.
     mkldnn::memory::dims dims(this_desc.data.dims,
                               this_desc.data.dims + this_desc.data.ndims);
    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
     mkldnn::memory::desc data_md(dims, this_dtype, this_format);
     mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
-    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
     stream->RegisterMem(tmp_mem);
     stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
   } else if (from_pd == this_pd) {
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 6a70ae40ac8..bead905476b 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -30,6 +30,7 @@
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
 #include "./mkldnn_convolution-inl.h"
+#include <sys/time.h>
 
 namespace mxnet {
 namespace op {
@@ -281,6 +282,11 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
   auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(
       fwd->fwd_pd.src_primitive_desc());
   const mkldnn::memory *weight_mem;
+  auto NumFilter = weight.shape()[0];
+  auto weight_OC = weight.shape()[1];
+  auto weight_K0 = weight.shape()[2];
+  auto weight_K1 = weight.shape()[3];
+
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
     // to the default format for now.
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 77d9bf06e2d..cf48cd17488 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -103,6 +103,45 @@ DMLC_REGISTER_PARAMETER(StackParam);
 DMLC_REGISTER_PARAMETER(SqueezeParam);
 DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
 
+static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+#if MXNET_USE_MKLDNN == 1
+  if (inputs[0].IsMKLDNNData()) {
+    MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    // This happens if inputs are supposed to be in MKLDNN format
+    // but MKLDNN doesn't support the data type or the shape. We're
+    // forced to convert it to the default format.
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+#endif
+}
+
+inline static bool ReshapeStorageType(const nnvm::NodeAttrs &attrs,
+                                      const int dev_mask,
+                                      DispatchMode *dispatch_mode,
+                                      std::vector<int> *in_attrs,
+                                      std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  bool ret = ElemwiseStorageType<1, 1, false, false, false>(attrs, dev_mask, dispatch_mode,
+                                                            in_attrs, out_attrs);
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask
+      && in_attrs->at(0) == kDefaultStorage
+      && out_attrs->at(0) == kDefaultStorage) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  return ret;
+}
+
 NNVM_REGISTER_OP(Reshape)
 .add_alias("reshape")
 .describe(R"code(Reshapes the input array.
@@ -171,9 +210,19 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<ReshapeParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ReshapeShape)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
+#endif
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
+#else
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs) {
     return std::vector<std::pair<int, int> >{{0, 0}};
@@ -182,6 +231,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
   [](const NodeAttrs& attrs){
     return std::vector<bool>{true};
   })
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
 .add_arguments(ReshapeParam::__FIELDS__());
 
@@ -914,7 +964,7 @@ NNVM_REGISTER_OP(depth_to_space)
 .describe(R"code(Rearranges(permutes) data from depth into blocks of spatial 
data.
 Similar to ONNX DepthToSpace operator:
 https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace.
-The output is a new tensor where the values from depth dimension are moved in spatial blocks 
+The output is a new tensor where the values from depth dimension are moved in spatial blocks
to height and width dimension. The reverse of this operation is ``space_to_depth``.
 
 .. math::
@@ -925,7 +975,7 @@ to height and width dimension. The reverse of this operation is ``space_to_depth
    y = reshape(x \prime \prime, [N, C / (block\_size ^ 2), H * block\_size, W * block\_size])
     \end{gather*}
 
-where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] 
+where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width]
and :math:`y` is the output tensor of layout :math:`[N, C / (block\_size ^ 2), H * block\_size, W * block\_size]`
 
 Example::
@@ -965,9 +1015,9 @@ Example::
 NNVM_REGISTER_OP(space_to_depth)
 .describe(R"code(Rearranges(permutes) blocks of spatial data into depth.
 Similar to ONNX SpaceToDepth operator:
-https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth 
+https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth
 
-The output is a new tensor where the values from height and width dimension are 
+The output is a new tensor where the values from height and width dimension are
moved to the depth dimension. The reverse of this operation is ``depth_to_space``.
 
 .. math::
@@ -978,7 +1028,7 @@ moved to the depth dimension. The reverse of this operation is ``depth_to_space`
    y = reshape(x \prime \prime, [N, C * (block\_size ^ 2), H / block\_size, W / block\_size])
     \end{gather*}
 
-where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width] 
+where :math:`x` is an input tensor with default layout as :math:`[N, C, H, W]`: [batch, channels, height, width]
and :math:`y` is the output tensor of layout :math:`[N, C * (block\_size ^ 2), H / block\_size, W / block\_size]`
 
 Example::
@@ -987,8 +1037,8 @@ Example::
          [12, 18, 13, 19, 14, 20],
          [3, 9, 4, 10, 5, 11],
          [15, 21, 16, 22, 17, 23]]]]
-  
-  
+
+
   space_to_depth(x, 2) = [[[[0, 1, 2],
                             [3, 4, 5]],
                            [[6, 7, 8],
diff --git a/tests/python/quantization/test_subgraph.py b/tests/python/quantization/test_subgraph.py
new file mode 100644
index 00000000000..0ee4be6d8b8
--- /dev/null
+++ b/tests/python/quantization/test_subgraph.py
@@ -0,0 +1,202 @@
+import mxnet as mx
+import numpy as np
+import argparse
+import ctypes
+import unittest
+from common import with_seed
+from mxnet.io import NDArrayIter
+from mxnet.module import Module
+from mxnet.symbol import Symbol
+from importlib import import_module
+from numpy.testing import assert_allclose
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str
+from mxnet.test_utils import DummyIter
+
+def check_qsym_calibrated(qsym):
+  attrs = qsym.attr_dict()
+  min_value = 0.0
+  max_value = 0.0
+  assert ''.join(qsym.attr_dict().keys()).find('quantized_') != -1
+  for k, v in attrs.items():
+    if k.find('quantized_sg_mkldnn_conv') != -1:
+      assert 'min_calib_range' in v
+      assert 'max_calib_range' in v
+      min_value = v['min_calib_range']
+      max_value = v['max_calib_range']
+    if k.find('_quantize') != -1:
+      assert v['out_type'] == 'uint8'
+  return float(min_value), float(max_value)
+
+def check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape):
+  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod.bind(for_training=False,
+           data_shapes=[('data', data_shape)],
+           label_shapes=[('softmax_label', label_shape)])
+  mod.set_params(qarg_params, qaux_params)
+  batch = mx.io.DataBatch(data_val, [])
+  mod.forward(batch, is_train=False)
+  for output in mod.get_outputs():
+    output.wait_to_read()
+  return output
+
+def check_quantize(sym, data_shape, label_shape, data_val, sym_output):
+    mod = Module(symbol=sym)
+    mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)], for_training=False)
+    mod.init_params()
+    arg_params, aux_params = mod.get_params()
+    excluded_sym_names = []
+    if mx.current_context() == mx.cpu():
+      excluded_sym_names += ['fc']
+    calib_data = mx.nd.random.uniform(shape=data_shape)
+    calib_data = NDArrayIter(data=calib_data)
+    calib_data = DummyIter(calib_data)
+    calib_layer = lambda name: name.endswith('_output')
+    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                     arg_params=arg_params,
+                                                                     aux_params=aux_params,
+                                                                     ctx=mx.current_context(),
+                                                                     excluded_sym_names=excluded_sym_names,
+                                                                     quantized_dtype='uint8',
+                                                                     calib_mode='naive',
+                                                                     calib_data=calib_data,
+                                                                     calib_layer=calib_layer,
+                                                                     disable_requantize=True,
+                                                                     calib_quantize_op=True,
+                                                                     num_calib_examples=20)
+    minVar, maxVar = check_qsym_calibrated(qsym)
+    rtol = (maxVar - minVar) / 256
+    qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape)
+    assert_allclose(qsym_output[0].asnumpy(), sym_output[0].asnumpy(), rtol=rtol)
+
+def check_fusion(sym, date_shape, label_shape, name, nofusion=False):
+  exe = sym.simple_bind(mx.cpu(), data=date_shape, grad_req='null')
+  out = SymbolHandle()
+  backend = "MKLDNN"
+  check_call(_LIB.MXGenBackendSubgraph(c_str(backend), sym.handle, ctypes.byref(out)))
+  sym_sg = Symbol(out)
+  exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null')
+
+  for k, v in exe.arg_dict.items():
+    v = mx.random.uniform(-1.0, 1.0, shape=v.shape)
+  data_val = [exe.arg_dict['data']]
+
+  fwd = exe.forward(is_train=False)
+  fwd[0].wait_to_read()
+
+  fwd_sg = exe_sg.forward(is_train=False)
+  fwd_sg[0].wait_to_read()
+
+  # Check the result accuracy based on fp32 fusion
+  assert_allclose(fwd[0].asnumpy(), fwd_sg[0].asnumpy(), rtol=0)
+  attrs=sym_sg.attr_dict()
+  if not nofusion:
+    assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1
+  for k, v in attrs.items():
+    if k.find('sg_mkldnn_conv') != -1:
+      for attr_op in name:
+        assert v[attr_op] == 'true'
+
+  # fp32 to uint8
+  if nofusion:
+    check_quantize(sym, date_shape, label_shape, data_val, fwd[0])
+  else: check_quantize(sym_sg, date_shape, label_shape, data_val, fwd[0])
+
+def single_conv():
+  data = mx.symbol.Variable('data')
+  weight = mx.symbol.Variable('weight')
+  bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn')
+  conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  fc = mx.sym.FullyConnected(data=conv, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+def conv_bn():
+  data = mx.symbol.Variable('data')
+  weight = mx.symbol.Variable('weight')
+  bn1 = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn1')
+  conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  bn = mx.symbol.BatchNorm(data=conv, name="bn")
+  fc = mx.sym.FullyConnected(data=bn, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+def conv_relu():
+  data = mx.symbol.Variable('data')
+  weight = mx.symbol.Variable('weight')
+  bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn')
+  conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu")
+  fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+def conv_sum():
+  data = mx.symbol.Variable('data')
+  weight = mx.symbol.Variable('weight')
+  bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn')
+  conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  conv1 = mx.symbol.Convolution(data=bn, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  sum1 = conv + conv1
+  fc = mx.sym.FullyConnected(data=sum1, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+def conv_bn_relu():
+  data = mx.symbol.Variable('data')
+  weight = mx.symbol.Variable('weight')
+  bn1 = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn1')
+  conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  bn = mx.symbol.BatchNorm(data=conv, name="bn")
+  relu = mx.symbol.Activation(data=bn, name='relu', act_type="relu")
+  fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+def conv_bn_sum_relu():
+  data = mx.symbol.Variable('data')
+  weight = mx.symbol.Variable('weight')
+  bn1 = mx.symbol.BatchNorm(data=data, name="bn1")
+  conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  bn = mx.symbol.BatchNorm(data=conv, name="bn")
+  conv1 = mx.symbol.Convolution(data=bn1, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1))
+  sum1 = bn + conv1
+  relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu")
+  fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+def int8_pooling():
+  data = mx.symbol.Variable('data')
+  bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn')
+  pool = mx.sym.Pooling(data=bn, kernel=(4, 4), pool_type='avg', name='pool')
+  fc = mx.sym.FullyConnected(data=pool, num_hidden=10, flatten=True, name='fc')
+  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+  return sym
+
+@with_seed()
+def test_sugbraph():
+  conv_attr = ['']
+  conv_relu_attr = ['with_relu']
+  conv_bn_attr = ['with_bn']
+  conv_sum_attr = ['with_sum']
+  conv_bn_relu_attr = ['with_bn', 'with_relu']
+  conv_bn_sum_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn']
+
+  shape = [(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)]
+  label = [(4, 10), (32, 10), (64, 10)]
+
+  for date_shape, label_shape in zip(shape, label):
+    net = conv_bn_sum_relu()
+    check_fusion(net, date_shape, label_shape, conv_bn_sum_relu_attr)
+    net = single_conv()
+    check_fusion(net, date_shape, label_shape, conv_attr)
+    net = conv_relu()
+    check_fusion(net, date_shape, label_shape, conv_relu_attr)
+    net = conv_bn()
+    check_fusion(net, date_shape, label_shape, conv_bn_attr)
+    net = conv_sum()
+    check_fusion(net, date_shape, label_shape, conv_sum_attr)
+    net = conv_bn_relu()
+    check_fusion(net, date_shape, label_shape, conv_bn_relu_attr)
+    net = int8_pooling()
+    check_fusion(net, date_shape, label_shape, '', True)
\ No newline at end of file
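
Illustrative usage (not part of the diff above): a minimal sketch of how the new
MKL-DNN Reshape path would be exercised, assuming an MXNet build with
MXNET_USE_MKLDNN=1. The output of a convolution is normally held in an MKL-DNN
layout on CPU, so the Reshape that follows it should dispatch to the
ReshapeComputeExCPU function registered as FComputeEx above, rather than falling
back to the default IdentityCompute kernel. The symbol names and shapes below
are arbitrary and chosen only for the example.

  import mxnet as mx

  # Convolution followed by a reshape; on an MKL-DNN CPU build the conv output
  # should stay in MKL-DNN layout, so the reshape should take the new MKL-DNN path.
  data = mx.symbol.Variable('data')
  conv = mx.symbol.Convolution(data=data, num_filter=8, kernel=(3, 3), name='conv')
  out = mx.symbol.reshape(conv, shape=(0, -1), name='reshape')

  exe = out.simple_bind(mx.cpu(), data=(1, 3, 24, 24), grad_req='null')
  exe.arg_dict['data'][:] = mx.nd.random.uniform(-1.0, 1.0, shape=(1, 3, 24, 24))
  y = exe.forward(is_train=False)[0]
  y.wait_to_read()
  print(y.shape)  # expected (1, 8 * 22 * 22) = (1, 3872)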


 
