This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d4e458e  [Quantization] Support zero-size tensor input for quantization flow (#15031)
d4e458e is described below

commit d4e458eddcafdb94648cdbd6c10a04101ce71466
Author: ciyong <[email protected]>
AuthorDate: Thu May 23 09:38:57 2019 +0800

    [Quantization] Support zero-size tensor input for quantization flow (#15031)
    
    * [Quantization] Support zero-size tensor input for quantization flow
    
    * Comment out quantized_act and quantized_sum
    
    * retrigger CI
    
    * Add test cases
---
 src/operator/quantization/dequantize-inl.h         |   7 +
 src/operator/quantization/quantize-inl.h           |  11 +-
 src/operator/quantization/quantize_v2-inl.h        |  11 +-
 src/operator/quantization/quantized_activation.cc  |   4 +
 .../quantization/quantized_elemwise_add.cc         |   4 +
 .../quantization/quantized_fully_connected.cc      |  18 +-
 tests/python/quantization/test_quantization.py     | 256 ++++++++-------------
 7 files changed, 146 insertions(+), 165 deletions(-)

diff --git a/src/operator/quantization/dequantize-inl.h b/src/operator/quantization/dequantize-inl.h
index 92b74b7..b5f9e38 100644
--- a/src/operator/quantization/dequantize-inl.h
+++ b/src/operator/quantization/dequantize-inl.h
@@ -74,11 +74,18 @@ inline bool DequantizeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);
 
+  mxnet::TShape dshape = (*in_attrs)[0];
   for (size_t i = 1; i < 3; ++i) {
     SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
   }
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+
+  if ((*out_attrs)[0].ndim() > 0) {
+    dshape[0] = ((*out_attrs)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+  }
+
   return shape_is_known(out_attrs->at(0));
 }
 
diff --git a/src/operator/quantization/quantize-inl.h b/src/operator/quantization/quantize-inl.h
index 7b85657..5108b13 100644
--- a/src/operator/quantization/quantize-inl.h
+++ b/src/operator/quantization/quantize-inl.h
@@ -119,13 +119,20 @@ inline bool QuantizeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 3U);
 
+  mxnet::TShape dshape = (*in_attrs)[0];
   for (size_t i = 1; i < 3; ++i) {
     SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
   }
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape{1});
-  SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1));
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(1, 1));
+
+  if ((*out_attrs)[0].ndim() > 0) {
+    dshape[0] = ((*out_attrs)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+  }
+
   return shape_is_known(out_attrs->at(0));
 }
 
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index a8cbc0b..d8814cc 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -109,9 +109,16 @@ static inline bool QuantizeV2Shape(const nnvm::NodeAttrs 
&attrs, std::vector<TSh
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 3U);
 
+  mxnet::TShape dshape = (*in_attrs)[0];
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1});
-  SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1));
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1));
+
+  if ((*out_attrs)[0].ndim() > 0) {
+    dshape[0] = ((*out_attrs)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+  }
+
   return !shape_is_none(out_attrs->at(0));
 }
 
diff --git a/src/operator/quantization/quantized_activation.cc b/src/operator/quantization/quantized_activation.cc
index 4ab74d0..95c17ed 100644
--- a/src/operator/quantization/quantized_activation.cc
+++ b/src/operator/quantization/quantized_activation.cc
@@ -115,6 +115,9 @@ the float32 data into int8.
 .add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.")
 .add_arguments(ActivationParam::__FIELDS__());
 
+// TODO(zhiyuan): need extra condition check if there's benefited if it's switched on
+// Since it's not compute-intensive.
+#if 0
 NNVM_REGISTER_OP(Activation)
 .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
   ActivationParam param;
@@ -133,6 +136,7 @@ NNVM_REGISTER_OP(Activation)
   }
   return node;
 });
+#endif
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/quantization/quantized_elemwise_add.cc b/src/operator/quantization/quantized_elemwise_add.cc
index f821e65..0e7034e 100644
--- a/src/operator/quantization/quantized_elemwise_add.cc
+++ b/src/operator/quantization/quantized_elemwise_add.cc
@@ -125,6 +125,9 @@ and max thresholds representing the threholds for 
quantizing the float32 output
 .add_argument("rhs_max", "NDArray-or-Symbol", "6th input");
 
 
+// TODO(zhangrong): need extra condition check if there's benefited if it's switched on
+// Since it's not compute-intensive.
+#if 0
 NNVM_REGISTER_OP(elemwise_add)
 .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
   nnvm::NodePtr node = nnvm::Node::Create();
@@ -136,6 +139,7 @@ NNVM_REGISTER_OP(elemwise_add)
   }
   return node;
 });
+#endif
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index ceac0b6..23790ca 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -47,9 +47,10 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& 
attrs,
   CHECK_EQ(in_shape->size(), num_inputs * 3);
   CHECK_EQ(out_shape->size(), 3U);
 
-  CHECK(shape_is_known(in_shape->at(0)))
-    << "QuantizedFullyConnectedOp input data shape must be given";
-  const mxnet::TShape& dshape = in_shape->at(0);
+  mxnet::TShape dshape = (*in_shape)[0];
+  // require data ndim to be known
+  if (!mxnet::ndim_is_known(dshape)) return false;
+
   index_t num_input;
   if (!param.flatten) {
     num_input = dshape[dshape.ndim() - 1];
@@ -57,7 +58,7 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& 
attrs,
     num_input = dshape.ProdShape(1, dshape.ndim());
   }
 
-  TShape wshape = Shape2(param.num_hidden, num_input);
+  mxnet::TShape wshape = Shape2(param.num_hidden, num_input);
   SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
   if (!param.no_bias) {
     mxnet::TShape bshape = Shape1(param.num_hidden);
@@ -65,11 +66,11 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& 
attrs,
   }
 
   for (size_t i = num_inputs; i < 3 * num_inputs; ++i) {
-    SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape{1});
+    SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(1, 1));
   }
 
   if (!param.flatten) {
-    TShape result_shape(dshape);
+    mxnet::TShape result_shape(dshape);
     result_shape[dshape.ndim() - 1] = param.num_hidden;
     SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
   } else {
@@ -77,6 +78,11 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& 
attrs,
   }
   SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
   SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));
+
+  if ((*out_shape)[0].ndim() > 0) {
+    dshape[0] = ((*out_shape)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_shape, 0, dshape);
+  }
   return true;
 }
 
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index ce93f98..294e107 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -748,9 +748,6 @@ def test_quantize_model_with_forward():
         if is_test_for_native_cpu():
             print('skipped testing test_quantize_model_with_forward for native 
cpu since it is not supported yet')
             return
-        elif qdtype == 'int8' and is_test_for_mkldnn():
-            print('skipped testing test_quantize_model_with_forward for mkldnn 
cpu int8 since it is not supported yet')
-            return
         elif qdtype == 'uint8' and is_test_for_gpu():
             print('skipped testing test_quantize_model_with_forward for gpu 
uint8 since it is not supported yet')
             return
@@ -782,11 +779,16 @@ def test_quantize_model_with_forward():
                     assert 'out_type' in v
                     assert v['out_type'] == qdtype
 
-        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, 
label_shape):
-            mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
-            mod.bind(for_training=False,
-                     data_shapes=[('data', data_shape)],
-                     label_shapes=[('softmax_label', label_shape)])
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, 
label_shape=None):
+            if label_shape is None:
+                mod = mx.mod.Module(symbol=qsym, label_names=None, 
context=mx.current_context())
+                mod.bind(for_training=False,
+                         data_shapes=[('data', data_shape)])
+            else:
+                mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+                mod.bind(for_training=False,
+                         data_shapes=[('data', data_shape)],
+                         label_shapes=[('softmax_label', label_shape)])
             mod.set_params(qarg_params, qaux_params)
             data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in 
mod.data_shapes]
             batch = mx.io.DataBatch(data, [])
@@ -794,165 +796,109 @@ def test_quantize_model_with_forward():
             for output in mod.get_outputs():
                 output.wait_to_read()
 
-        sym = get_fp32_residual()
         batch_size = 4
-        data_shape = (batch_size, 4, 10, 10)
-        label_shape = (batch_size, 10)
-
         length = batch_size  # specify num of outputs from split op
-        msym = get_fp32_sym_with_multiple_outputs(length)
-        msym_label_shape = (length, 10)
-        msym_data_shape = (length, 4, 4, 10, 10)
+        sym_list = []
+        name_list = []
+        dshape_list = []
+        lshape_list = []
+
+        # sym 1
+        sym_list.append(get_fp32_residual())
+        name_list.append('sym1')
+        dshape_list.append((batch_size, 4, 10, 10))
+        lshape_list.append((batch_size, 10))
+
+        # sym 2
+        sym_list.append(get_fp32_sym_with_multiple_outputs(length))
+        name_list.append('sym2')
+        dshape_list.append((length, 4, 4, 10, 10))
+        lshape_list.append((length, 10))
 
-        for s, dshape, lshape in zip((sym, msym), (data_shape, 
msym_data_shape),
-                                     (label_shape, msym_label_shape)):
-            mod = Module(symbol=s)
-            mod.bind(data_shapes=[('data', dshape)], 
label_shapes=[('softmax_label', lshape)])
+        data = mx.sym.Variable('data')
+        # sym 3
+        sym_list.append(mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, 
name='conv0'))
+        name_list.append('sym3')
+        dshape_list.append((batch_size, 4, 10, 10))
+        lshape_list.append(None)
+
+        # sym 4
+        cell = mx.rnn.LSTMCell(num_hidden=64)
+        outputs, _ = cell.unroll(length, data)
+        sym_list.append(mx.sym.Group(outputs))
+        name_list.append('sym4')
+        dshape_list.append((batch_size, length, 32))
+        lshape_list.append(None)
+
+        for s, dshape, lshape, name in zip(sym_list, dshape_list, lshape_list, 
name_list):
+            if qdtype == 'int8' and is_test_for_mkldnn() and name in ['sym1', 
'sym2', 'sym3']:
+              print('skipped testing test_quantize_model_with_forward for 
mkldnn cpu int8 since it is not supported yet')
+              continue
+
+            if lshape is None:
+                mod = Module(symbol=s, label_names=None)
+                mod.bind(for_training=False,
+                         data_shapes=[('data', dshape)])
+            else:
+                mod = Module(symbol=s)
+                mod.bind(for_training=False,
+                         data_shapes=[('data', dshape)],
+                         label_shapes=[('softmax_label', lshape)])
 
             mod.init_params()
             arg_params, aux_params = mod.get_params()
-            excluded_names = []
-            if mx.current_context() == mx.cpu():
-               excluded_names += ['fc', 'conv1']
-            if mx.current_context() == mx.gpu():
-               excluded_names += ['sum0', 'relu0', 'relu1']
-            excluded_names += ['concat']
-
-            optional_names = ['pool0']
-            for skip_optional_names in [False, True]:
-                exclude_sym_names = []
-                if skip_optional_names:
-                    excluded_sym_names = excluded_names
-                else:
-                    excluded_sym_names = excluded_names + optional_names
-
-                qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=s,
-                                                                               
  arg_params=arg_params,
-                                                                               
  aux_params=aux_params,
-                                                                               
  excluded_sym_names=excluded_sym_names,
-                                                                               
  ctx=mx.current_context(),
-                                                                               
  quantized_dtype=qdtype,
-                                                                               
  calib_mode='none')
-                check_params(arg_params, qarg_params, qsym)
-                check_params(aux_params, qaux_params)
-                check_qsym_forward(qsym, qarg_params, qaux_params, dshape, 
lshape)
-
-                calib_data = mx.nd.random.uniform(shape=dshape)
-                calib_data = NDArrayIter(data=calib_data, 
batch_size=batch_size)
-                calib_data = DummyIter(calib_data)
-                qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=s,
-                                                                               
  arg_params=arg_params,
-                                                                               
  aux_params=aux_params,
-                                                                               
  excluded_sym_names=excluded_sym_names,
-                                                                               
  ctx=mx.current_context(),
-                                                                               
  quantized_dtype=qdtype,
-                                                                               
  calib_mode='naive',
-                                                                               
  calib_data=calib_data,
-                                                                               
  num_calib_examples=20)
-                check_params(arg_params, qarg_params, qsym)
-                check_params(aux_params, qaux_params)
-                check_qsym_calibrated(qsym)
-                check_qsym_qdtype(qsym, qdtype)
-                check_qsym_forward(qsym, qarg_params, qaux_params, dshape, 
lshape)
-
-    for qdtype in ['int8', 'uint8']:
-        check_quantize_model(qdtype)
-
-@with_seed()
-def test_quantize_conv_with_forward():
-    def check_quantize_model(qdtype):
-        if is_test_for_native_cpu():
-            print('skipped testing test_quantize_model_with_forward for native 
cpu since it is not supported yet')
-            return
-        elif qdtype == 'int8' and is_test_for_mkldnn():
-            print('skipped testing test_quantize_model_with_forward for mkldnn 
cpu int8 since it is not supported yet')
-            return
-        elif qdtype == 'uint8' and is_test_for_gpu():
-            print('skipped testing test_quantize_model_with_forward for gpu 
uint8 since it is not supported yet')
-            return
 
-        def check_params(params, qparams, qsym=None):
-            if qsym is None:
-                assert len(params) == len(qparams)
-                for k, v in params.items():
-                    assert k in qparams
-                    assert same(v.asnumpy(), qparams[k].asnumpy())
-            else:
-                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, 
params, th_dict = {})
-                assert len(qparams) == len(qparams_ground_truth)
-                for k, v in qparams_ground_truth.items():
-                    assert k in qparams
-                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            excluded_sym_names = []
+            # sym3/sym4 doesn't have such layers
+            if name not in ['sym3', 'sym4']:
+                excluded_names = []
+                if mx.current_context() == mx.cpu():
+                   excluded_names += ['fc', 'conv1']
+                if mx.current_context() == mx.gpu():
+                   excluded_names += ['sum0', 'relu0', 'relu1']
+                excluded_names += ['concat']
+
+                optional_names = ['pool0']
+                for skip_optional_names in [False, True]:
+                    exclude_sym_names = []
+                    if skip_optional_names:
+                        excluded_sym_names = excluded_names
+                    else:
+                        excluded_sym_names = excluded_names + optional_names
 
-        def check_qsym_calibrated(qsym):
-            attrs = qsym.attr_dict()
-            for k, v in attrs.items():
-                if k.find('requantize_') != -1:
-                    assert 'min_calib_range' in v
-                    assert 'max_calib_range' in v
-
-        def check_qsym_qdtype(qsym, qdtype):
-            attrs = qsym.attr_dict()
-            for k, v in attrs.items():
-                if k.find('_quantize') != -1:
-                    assert 'out_type' in v
-                    assert v['out_type'] == qdtype
+            qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=s,
+                                                                             
arg_params=arg_params,
+                                                                             
aux_params=aux_params,
+                                                                             
excluded_sym_names=excluded_sym_names,
+                                                                             
ctx=mx.current_context(),
+                                                                             
quantized_dtype=qdtype,
+                                                                             
calib_mode='none')
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
 
-        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
-            mod = mx.mod.Module(symbol=qsym, label_names=None, 
context=mx.current_context())
-            mod.bind(for_training=False,
-                     data_shapes=[('data', data_shape)])
-            mod.set_params(qarg_params, qaux_params)
-            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in 
mod.data_shapes]
-            batch = mx.io.DataBatch(data, [])
-            mod.forward(batch, is_train=False)
-            for output in mod.get_outputs():
-                output.wait_to_read()
+            calib_data = mx.nd.random.uniform(shape=dshape)
+            calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+            calib_data = DummyIter(calib_data)
+            qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=s,
+                                                                             
arg_params=arg_params,
+                                                                             
aux_params=aux_params,
+                                                                             
excluded_sym_names=excluded_sym_names,
+                                                                             
ctx=mx.current_context(),
+                                                                             
quantized_dtype=qdtype,
+                                                                             
calib_mode='naive',
+                                                                             
calib_data=calib_data,
+                                                                             
num_calib_examples=20)
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_calibrated(qsym)
+            check_qsym_qdtype(qsym, qdtype)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
 
-        batch_size = 4
-        dshape = (batch_size, 4, 10, 10)
-        data = mx.sym.Variable('data')
-        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, 
name='conv0')
-
-        mod = Module(symbol=sym, label_names=None)
-        mod.bind(data_shapes=[('data', dshape)])
-
-        mod.init_params()
-        arg_params, aux_params = mod.get_params()
-        excluded_sym_names = []
-
-        qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=sym,
-                                                                            
arg_params=arg_params,
-                                                                            
aux_params=aux_params,
-                                                                            
excluded_sym_names=excluded_sym_names,
-                                                                            
ctx=mx.current_context(),
-                                                                            
quantized_dtype=qdtype,
-                                                                            
calib_mode='none')
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
-
-        calib_data = mx.nd.random.uniform(shape=dshape)
-        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
-        calib_data = DummyIter(calib_data)
-        qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=sym,
-                                                                            
arg_params=arg_params,
-                                                                            
aux_params=aux_params,
-                                                                            
excluded_sym_names=excluded_sym_names,
-                                                                            
ctx=mx.current_context(),
-                                                                            
quantized_dtype=qdtype,
-                                                                            
calib_mode='naive',
-                                                                            
calib_data=calib_data,
-                                                                            
num_calib_examples=20)
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_calibrated(qsym)
-        check_qsym_qdtype(qsym, qdtype)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
-
-    for qdtype in ['uint8', 'int8']:
+    for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
 
+
 @with_seed()
 def test_quantize_sym_with_calib():
     sym = get_fp32_sym()

Reply via email to