This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f363da3 [BUGFIX] fix npi_concatenate quantization dim/axis (#20383)
f363da3 is described below
commit f363da3333b51392fed1237f842ed9209fd28338
Author: Sylwester Fraczek <[email protected]>
AuthorDate: Wed Jun 30 10:37:24 2021 +0200
[BUGFIX] fix npi_concatenate quantization dim/axis (#20383)
* fix npi_concatenate quantization dim/axis
---
src/operator/quantization/quantized_concat.cc | 23 +++++++++++++++++++++++
tests/python/mkl/subgraphs/test_conv_subgraph.py | 1 -
2 files changed, 23 insertions(+), 1 deletion(-)
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
index e4a15cc..f955e7e 100644
--- a/src/operator/quantization/quantized_concat.cc
+++ b/src/operator/quantization/quantized_concat.cc
@@ -24,6 +24,7 @@
*/
#include "../nn/concat-inl.h"
+#include "../numpy/np_matrix_op-inl.h"
namespace mxnet {
namespace op {
@@ -157,5 +158,27 @@ NNVM_REGISTER_OP(Concat)
return node;
});
+// Quantization hook for the numpy concatenate operator (_npi_concatenate).
+// The parsed "axis" decides whether the node is converted: a positive axis
+// maps it to _contrib_quantized_concat; otherwise (axis <= 0, or no axis) the
+// node is excluded, as logged, and a node whose op is nullptr is returned.
+// _contrib_quantized_concat takes the concatenation axis as "dim", so the
+// "axis" attribute is renamed before that op's attr_parser runs.
+NNVM_REGISTER_OP(_npi_concatenate)
+.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+  const NumpyConcatenateParam& param = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  node->attrs.dict = attrs.dict;
+  if (param.axis.has_value() && param.axis.value() > 0) {
+    node->attrs.op = Op::Get("_contrib_quantized_concat");
+    node->attrs.name = "quantized_" + attrs.name;
+    // Rename "axis" -> "dim" only on the node that is actually converted, so
+    // only "dim" reaches the quantized op's attr_parser. find() avoids
+    // operator[] default-inserting an empty "axis". The value is copied and
+    // the old key erased BEFORE inserting "dim", because inserting into an
+    // unordered map can rehash and would invalidate axis_it.
+    auto& dict = node->attrs.dict;
+    auto axis_it = dict.find("axis");
+    if (axis_it != dict.end()) {
+      const auto dim_value = axis_it->second;
+      dict.erase(axis_it);
+      dict["dim"] = dim_value;
+    }
+  } else {
+    LOG(INFO) << "Currently, quantized numpy concatenate only supports axis>0, exclude "
+              << attrs.name << " which axis is " << param.axis;
+    node->attrs.op = nullptr;
+    node->attrs.name = attrs.name;
+  }
+  if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
+    node->op()->attr_parser(&(node->attrs));
+  }
+  return node;
+});
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/mkl/subgraphs/test_conv_subgraph.py b/tests/python/mkl/subgraphs/test_conv_subgraph.py
index da08c81..a4efab4 100644
--- a/tests/python/mkl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -289,7 +289,6 @@ def test_pos_single_concat_pos_neg(data_shape, out_type):
@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('out_type', ['int8', 'auto'])
-@pytest.mark.skip("Scale doesn't align in numpy for numpy operators")
def test_pos_concat_scale_align(data_shape, out_type):
# concat scale alignment case
class ConcatScaleAlign(nn.HybridBlock):