This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f363da3  [BUGFIX] fix npi_concatenate quantization dim/axis (#20383)
f363da3 is described below

commit f363da3333b51392fed1237f842ed9209fd28338
Author: Sylwester Fraczek <[email protected]>
AuthorDate: Wed Jun 30 10:37:24 2021 +0200

    [BUGFIX] fix npi_concatenate quantization dim/axis (#20383)
    
    * fix npi_concatenate quantization dim/axis
---
 src/operator/quantization/quantized_concat.cc    | 23 +++++++++++++++++++++++
 tests/python/mkl/subgraphs/test_conv_subgraph.py |  1 -
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
index e4a15cc..f955e7e 100644
--- a/src/operator/quantization/quantized_concat.cc
+++ b/src/operator/quantization/quantized_concat.cc
@@ -24,6 +24,7 @@
 */
 
 #include "../nn/concat-inl.h"
+#include "../numpy/np_matrix_op-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -157,5 +158,27 @@ NNVM_REGISTER_OP(Concat)
   return node;
 });
 
+NNVM_REGISTER_OP(_npi_concatenate)
+.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+  const NumpyConcatenateParam& param = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  if (param.axis.has_value() && param.axis.value() > 0) {
+    node->attrs.op = Op::Get("_contrib_quantized_concat");
+    node->attrs.name = "quantized_" + attrs.name;
+  } else {
+    LOG(INFO) << "Currently, quantized numpy concatenate only supports axis>0, exclude "
+              << attrs.name << " which axis is " << param.axis;
+    node->attrs.op = nullptr;
+    node->attrs.name = attrs.name;
+  }
+  node->attrs.dict = attrs.dict;
+  node->attrs.dict["dim"] = node->attrs.dict["axis"];
+  node->attrs.dict.erase("axis");
+  if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
+    node->op()->attr_parser(&(node->attrs));
+  }
+  return node;
+});
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/mkl/subgraphs/test_conv_subgraph.py b/tests/python/mkl/subgraphs/test_conv_subgraph.py
index da08c81..a4efab4 100644
--- a/tests/python/mkl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -289,7 +289,6 @@ def test_pos_single_concat_pos_neg(data_shape, out_type):
 @mx.util.use_np
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('out_type', ['int8', 'auto'])
-@pytest.mark.skip("Scale doesn't align in numpy for numpy operators")
 def test_pos_concat_scale_align(data_shape, out_type):
   # concat scale alignment case
   class ConcatScaleAlign(nn.HybridBlock):

Reply via email to