This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0da4b67  [MKLDNN]Add quantized relu (#14604)
0da4b67 is described below

commit 0da4b67ebf5788deef97ecaca5e30cbc9d27660d
Author: zhiyuan-huang <[email protected]>
AuthorDate: Thu Apr 18 20:18:45 2019 +0800

    [MKLDNN]Add quantized relu (#14604)
    
    * add quantized relu
    
    * fix testcase
    
    * add author and skip quantized-relu for gpu
    
    * fix comments
    
    * retrigger ci
    
    * retrigger ci
    
    * comment fix
    
    * retrigger ci
    
    * retrigger ci
---
 src/operator/nn/mkldnn/mkldnn_act-inl.h            |  74 +++++++++++
 src/operator/nn/mkldnn/mkldnn_act.cc               |  91 +++++---------
 .../quantization/mkldnn/mkldnn_quantized_act.cc    |  55 ++++++++
 src/operator/quantization/quantize_graph_pass.cc   |  29 +++--
 src/operator/quantization/quantized_activation.cc  | 138 +++++++++++++++++++++
 tests/python/quantization/test_quantization.py     |  55 +++++++-
 6 files changed, 372 insertions(+), 70 deletions(-)
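
For context on how the new operator is reached from user code: the fp32 Activation node is rewritten to _contrib_quantized_act by the quantization graph pass. A minimal Python sketch of that flow (illustrative only -- the model, shapes, and calib_mode are placeholder choices, and an MKLDNN-enabled CPU build of this commit is assumed):

import mxnet as mx
from mxnet.contrib import quantization

data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=64, name='fc')
act = mx.sym.Activation(fc, act_type='relu', name='relu0')
sym = mx.sym.SoftmaxOutput(act, name='softmax')

# Random parameters stand in for a trained model.
arg_shapes, _, _ = sym.infer_shape(data=(1, 32))
arg_params = {n: mx.nd.random.uniform(shape=s)
              for n, s in zip(sym.list_arguments(), arg_shapes)
              if n not in ('data', 'softmax_label')}

qsym, qarg_params, qaux_params = quantization.quantize_model(
    sym=sym, arg_params=arg_params, aux_params={},
    ctx=mx.cpu(), calib_mode='none', quantized_dtype='uint8')
# The relu0 node should now appear as _contrib_quantized_act in qsym.tojson().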

diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h
new file mode 100644
index 0000000..6bf30e3
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_act-inl.h
+ * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \author Zhiyuan Huang
+*/
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
+
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include <utility>
+#include "../activation-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
+mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
+    const ActivationParam& param, bool is_train,
+    const mkldnn::memory &input_mem, int dtype);
+
+class MKLDNNActForward {
+ public:
+  const mkldnn::eltwise_forward::primitive_desc fwd_pd;
+
+  MKLDNNActForward(const ActivationParam& param, bool is_train,
+                   const NDArray &data, const mkldnn::memory &mem): fwd_pd(
+                       GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
+  const mkldnn::eltwise_forward &GetFwd() const;
+
+ private:
+  std::shared_ptr<mkldnn::eltwise_forward> fwd_;
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
+MKLDNNActForward &GetActForward(const ActivationParam& param,
+                                const OpContext &ctx, const NDArray &in_data,
+                                const mkldnn::memory &in_mem);
+
+void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                             const NDArray &in_data, const OpReqType &req,
+                             const NDArray &out_data);
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 8c64888..9ce27fa 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -32,8 +32,7 @@
 #include <string>
 #include <utility>
 #include "../../operator_common.h"
-#include "../activation-inl.h"
-#include "./mkldnn_base-inl.h"
+#include "mkldnn_act-inl.h"
 
 #if MXNET_USE_MKLDNN == 1
 
@@ -58,7 +57,7 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
   return SupportMKLDNNAct(param);
 }
 
-static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
+mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
   switch (param.act_type) {
     case activation::kReLU:
       return mkldnn::algorithm::eltwise_relu;
@@ -74,9 +73,7 @@ static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
   }
 }
 
-typedef std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> mkldnn_act_pdesc_ptr;
-
-static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
+mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
     const ActivationParam& param, bool is_train,
     const mkldnn::memory &input_mem, int dtype) {
   mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
@@ -84,65 +81,41 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
   auto cpu_engine = data_mpd.get_engine();
 
   auto alg = GetMKLDNNActAlgo(param);
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc desc = is_train
-        ? mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_training,
-                                        alg, data_md, alpha)
-        : mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_scoring,
-                                        alg, data_md, alpha);
-    return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
-  });
-  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
-  mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
-      mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
+
+  auto prop = is_train ? mkldnn::prop_kind::forward_training :
+                         mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
   return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
 }
 
-typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
-
-class MKLDNNActForward {
-  std::shared_ptr<mkldnn::eltwise_forward> fwd;
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> out;
-
- public:
-  const mkldnn::eltwise_forward::primitive_desc fwd_pd;
-
-  MKLDNNActForward(const ActivationParam& param, bool is_train,
-                   const NDArray &data, const mkldnn::memory &mem): fwd_pd(
-                       GetActFwdDescImpl(param, is_train, mem, data.dtype())) {
-  }
-
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-    if (this->data == nullptr)
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              data.get_primitive_desc(), data.get_data_handle()));
-    else
-      this->data->set_data_handle(data.get_data_handle());
-
-    CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
-    if (this->out == nullptr)
-      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-    else
-      this->out->set_data_handle(output.get_data_handle());
-
-    if (this->fwd == nullptr) {
-      this->fwd = std::shared_ptr<mkldnn::eltwise_forward>(
-          new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data),
-                                      *this->out));
-    }
+void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
+  if (this->data_ == nullptr)
+    this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
+            data.get_data_handle());
+  else
+    this->data_->set_data_handle(data.get_data_handle());
+
+  CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
+  if (this->out_ == nullptr)
+    this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
+            output.get_data_handle());
+  else
+    this->out_->set_data_handle(output.get_data_handle());
+
+  if (this->fwd_ == nullptr) {
+    this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
+        new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
+                                    *this->out_));
   }
+}
 
-  const mkldnn::eltwise_forward &GetFwd() const {
-    return *fwd;
-  }
-};
+const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
+  return *fwd_;
+}
 
-static MKLDNNActForward &GetActForward(const ActivationParam& param,
-                                       const OpContext &ctx, const NDArray &in_data,
-                                       const mkldnn::memory &in_mem) {
+MKLDNNActForward &GetActForward(const ActivationParam& param,
+                                const OpContext &ctx, const NDArray &in_data,
+                                const mkldnn::memory &in_mem) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActForward, OpHash> fwds;
 #else
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
new file mode 100644
index 0000000..bc69cb5
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_quantized_act.cc
+ * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \author Zhiyuan Huang
+*/
+#if MXNET_USE_MKLDNN == 1
+
+#include "../../nn/mkldnn/mkldnn_act-inl.h"
+#include "../quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+static void MKLDNNQuantizedActForward(const nnvm::NodeAttrs& attrs,
+                                      const OpContext& ctx,
+                                      const std::vector<NDArray>& in_data,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& out_data) {
+  CHECK(in_data[0].dtype() == mshadow::kUint8 ||
+        in_data[0].dtype() == mshadow::kInt8)
+      << "_contrib_quantized_act op only supports uint8 and int8 as input "
+         "type";
+
+  MKLDNNActivationForward(attrs, ctx, in_data[0], req[0], out_data[0]);
+  out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
+  out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_act)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedActForward);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
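
Note on why the forward above can simply copy min/max through: relu commutes with linear quantization, so taking max(x, 0) on the stored int8 values while keeping the input thresholds is equivalent to computing relu in fp32 and re-quantizing with the same range. A small NumPy check of that property (illustrative; the symmetric int8 scheme below is an assumption mirroring MXNet's quantize op, not code from this patch):

import numpy as np

def quantize_int8(x, min_range, max_range):
    # Symmetric int8 quantization: scale by 127 / max(|min|, |max|).
    scale = 127.0 / max(abs(min_range), abs(max_range))
    return np.clip(np.round(x * scale), -127, 127).astype(np.int8)

x = np.random.uniform(-10, 10, size=(4, 4)).astype(np.float32)
qx = quantize_int8(x, x.min(), x.max())

q_relu = np.maximum(qx, 0)                                    # relu in the int8 domain
relu_q = quantize_int8(np.maximum(x, 0.0), x.min(), x.max())  # fp32 relu, same range

assert np.array_equal(q_relu, relu_q)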
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5bd9e8a..7ff2999 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -89,11 +89,12 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs,
   return outputs;
 }
 
-inline bool NeedQuantize(const NodePtr node,
-                         const std::unordered_set<std::string>& excluded_nodes) {
+inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set<std::string>& excluded_nodes) {
+  std::unordered_map<NodePtr, NodePtr> quantized_node;
   static auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
   const auto& op = node->op();
+
   if (op && quantized_op_map.count(op)) {
     bool need = true;
     if (excluded_nodes.count(node->attrs.name)) {
@@ -112,14 +113,24 @@ inline bool NeedQuantize(const NodePtr node,
         });
       }
     }
-    return need;
+
+    if (need) {
+      auto n_ptr = quantized_op_map[node->op()];
+      auto tmp_node = n_ptr(node->attrs);
+      if (tmp_node->op()) {
+        quantized_node[node] = tmp_node;
+      } else {
+        quantized_node[node] = nullptr;
+      }
+    } else {
+      quantized_node[node] = nullptr;
+    }
   }
-  return false;
+  return quantized_node[node];
 }
 
 Graph QuantizeGraph(Graph &&src) {
   static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  static const auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
   static const auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
   static const auto& avoid_quantize_input_map =
       Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
@@ -136,11 +147,9 @@ Graph QuantizeGraph(Graph &&src) {
     NodePtr new_node = Node::Create();
     // If the currently visited node needs quantization, insert a quantize op node before the
     // current node and replace the current node with the quantized version in the new graph.
-    if (NeedQuantize(node, excluded_nodes)) {
-      auto fquantized_op = quantized_op_map[node->op()];
-      // If the currently visited node's op registered the FQuantizedOp property, new_node is a
-      // quantizated version of a that op, such as quantized_conv2d.
-      new_node = fquantized_op(node->attrs);
+    auto tmp_node = NeedQuantize(node, excluded_nodes);
+    if (tmp_node) {
+      new_node = tmp_node;
 
       // add data into quantized op input
       for (size_t i = 0; i < node->inputs.size(); ++i) {
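
The pass change above turns NeedQuantize from a yes/no predicate into a factory: it now returns the quantized node itself, or nullptr when the node is excluded or when its FQuantizedOp callback returns a node with a null op -- the per-instance veto that the Activation registration below relies on for non-relu act_types. A Python analogue of the new contract (a structural sketch only, with hypothetical Node/callback types, not the actual C++):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    op: Optional[str]   # None plays the role of a null op pointer
    name: str
    attrs: dict

def need_quantize(node, excluded_nodes, quantized_op_map):
    fquantized_op = quantized_op_map.get(node.op)
    if fquantized_op is None or node.name in excluded_nodes:
        return None                       # node stays fp32
    qnode = fquantized_op(node.attrs)
    return qnode if qnode.op is not None else None  # null op == per-instance veto

def quantize_act(attrs):
    # Activation-style callback: only relu has a quantized counterpart.
    if attrs.get('act_type') == 'relu':
        return Node('_contrib_quantized_act', 'quantized_relu', attrs)
    return Node(None, 'relu', attrs)

qmap = {'Activation': quantize_act}
assert need_quantize(Node('Activation', 'relu0', {'act_type': 'relu'}),
                     set(), qmap).op == '_contrib_quantized_act'
assert need_quantize(Node('Activation', 'sig0', {'act_type': 'sigmoid'}),
                     set(), qmap) is None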
diff --git a/src/operator/quantization/quantized_activation.cc b/src/operator/quantization/quantized_activation.cc
new file mode 100644
index 0000000..4ab74d0
--- /dev/null
+++ b/src/operator/quantization/quantized_activation.cc
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file quantized_activation.cc
+*/
+#include <mxnet/op_attr_types.h>
+#include "../nn/activation-inl.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+bool QuantizedActivationShape(const nnvm::NodeAttrs& attrs,
+                              std::vector<TShape> *in_shape,
+                              std::vector<TShape> *out_shape) {
+  CHECK_EQ(in_shape->size(), 3U);
+  if (shape_is_none(in_shape->at(0))) return false;
+  SHAPE_ASSIGN_CHECK(*in_shape, 1, TShape{1});
+  SHAPE_ASSIGN_CHECK(*in_shape, 2, TShape{1});
+  out_shape->clear();
+  out_shape->push_back((*in_shape)[0]);
+  out_shape->push_back(TShape{1});
+  out_shape->push_back(TShape{1});
+  return true;
+}
+
+bool QuantizedActivationType(const nnvm::NodeAttrs& attrs,
+                             std::vector<int> *in_type,
+                             std::vector<int> *out_type) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  CHECK_EQ(in_type->size(), 3U);
+  CHECK_EQ(out_type->size(), 3U);
+  if (param.act_type == activation::kReLU) {
+    TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt8);
+  } else {
+    LOG(FATAL) << "_contrib_quantized_act only supports act_type=relu for now";
+  }
+  TYPE_ASSIGN_CHECK(*in_type, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_type, 2, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
+  return true;
+}
+
+inline static bool QuantizedActivationStorageType(const nnvm::NodeAttrs &attrs,
+                                                  const int dev_mask,
+                                                  DispatchMode *dispatch_mode,
+                                                  std::vector<int> *in_attrs,
+                                                  std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3);
+
+  *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_MKLDNN == 1
+  const ActivationParam &param = nnvm::get<ActivationParam>(attrs.parsed);
+  if (dev_mask == mshadow::cpu::kDevMask && param.act_type == activation::kReLU) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#else
+  CHECK_EQ(out_attrs->size(), 3);
+#endif
+  for (int& out_attr : *out_attrs)
+    out_attr = kDefaultStorage;
+  return true;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_act)
+.describe(R"code(Activation operator for input and output data type of int8.
+The input and output data comes with min and max thresholds for quantizing
+the float32 data into int8.
+
+.. Note::
+     This operator only supports forward propagation. DO NOT use it in training.
+     This operator only supports `relu`.)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<ActivationParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "min_data", "max_data"};
+  })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"output", "min_output", "max_output"};
+  })
+.set_attr<nnvm::FInferType>("FInferType", QuantizedActivationType)
+.set_attr<mxnet::FInferShape>("FInferShape", QuantizedActivationShape)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedActivationStorageType)
+.set_attr<FNeedRequantize>("FNeedRequantize",
+  [](const NodeAttrs& attrs) {
+    const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+    CHECK(param.act_type == activation::kReLU)
+      << "_contrib_quantized_act only supports act_type=relu for now";
+    return false;
+  })
+.add_argument("data", "NDArray-or-Symbol", "Input data.")
+.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.")
+.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.")
+.add_arguments(ActivationParam::__FIELDS__());
+
+NNVM_REGISTER_OP(Activation)
+.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+  ActivationParam param;
+  param.Init(attrs.dict);
+  nnvm::NodePtr node = nnvm::Node::Create();
+  if (param.act_type == activation::kReLU) {
+    node->attrs.op = Op::Get("_contrib_quantized_act");
+    node->attrs.name = "quantized_" + attrs.name;
+  } else {
+    node->attrs.op = nullptr;
+    node->attrs.name = attrs.name;
+  }
+  node->attrs.dict = attrs.dict;
+  if (node->op()->attr_parser != nullptr) {
+    node->op()->attr_parser(&(node->attrs));
+  }
+  return node;
+});
+
+}  // namespace op
+}  // namespace mxnet
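
The FQuantizedOp hook just above is what feeds the reworked NeedQuantize: for act_type='relu' it builds a _contrib_quantized_act node, otherwise it hands back a node with a null op so the graph pass keeps the fp32 Activation. A quick way to observe both outcomes (an illustrative sketch assuming an MKLDNN CPU build; the tiny symbol here is made up for the example):

import json
import mxnet as mx
from mxnet.contrib import quantization

data = mx.sym.Variable('data')
relu = mx.sym.Activation(data, act_type='relu', name='relu0')
sig = mx.sym.Activation(relu, act_type='sigmoid', name='sigmoid0')
sym = mx.sym.SoftmaxOutput(sig, name='softmax')

qsym, _, _ = quantization.quantize_model(
    sym=sym, arg_params={}, aux_params={},
    ctx=mx.cpu(), calib_mode='none', quantized_dtype='uint8')

ops = [node['op'] for node in json.loads(qsym.tojson())['nodes']]
assert '_contrib_quantized_act' in ops   # relu0 was replaced
assert 'Activation' in ops               # sigmoid0 stayed fp32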
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 757df81..2761e77 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -414,6 +414,57 @@ def test_quantized_flatten():
         check_quantized_flatten((10, 15, 18), qdtype)
         check_quantized_flatten((3, 4, 23, 23), qdtype)
 
+@with_seed()
+def test_quantized_act():
+    def check_quantized_act(data_shape, qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_act for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_act for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif is_test_for_gpu():
+            print('skipped testing quantized_act for gpu since it is not supported yet')
+            return
+        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+        act_fp32 = mx.sym.Activation(data=data, act_type='relu', name='relu')
+        arg_shapes, _, _ = act_fp32.infer_shape(data=data_shape)
+        arg_names = act_fp32.list_arguments()
+        act_fp32_exe = act_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+
+        act_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low,
+                                                high=data_high, shape=data_shape).astype(qdtype)
+        output = act_fp32_exe.forward()[0]
+
+        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
+        min_data = mx.sym.Variable(name='min_data')
+        max_data = mx.sym.Variable(name='max_data')
+        quantized_act = mx.sym.contrib.quantized_act(data=qdata, min_data=min_data, max_data=max_data, act_type='relu')
+        act_int8_exe = quantized_act.simple_bind(ctx=mx.current_context(), grad_req='null')
+        qarg_names = quantized_act.list_arguments()
+
+        act_int8_exe.arg_dict[qarg_names[0]][:] = act_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
+        quantized_range_min = mx.nd.min(act_int8_exe.arg_dict[qarg_names[0]][:])
+        quantized_range_max = mx.nd.max(act_int8_exe.arg_dict[qarg_names[0]][:])
+        act_int8_exe.arg_dict[qarg_names[1]][:] = quantized_range_min.astype(qdtype)
+        act_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range_max.astype(qdtype)
+        qoutput, min_range, max_range = act_int8_exe.forward()
+
+        assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
+        assert_almost_equal(min_range.asscalar(), quantized_range_min.asscalar())
+        assert_almost_equal(max_range.asscalar(), quantized_range_max.asscalar())
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_act((10,), qdtype)
+        check_quantized_act((10, 15), qdtype)
+        check_quantized_act((10, 15, 18), qdtype)
+        check_quantized_act((3, 4, 23, 23), qdtype)
 
 @with_seed()
 def test_quantize_params():
@@ -634,7 +685,9 @@ def test_quantize_model_with_forward():
             arg_params, aux_params = mod.get_params()
             excluded_names = []
             if mx.current_context() == mx.cpu():
-               excluded_names += ['fc']
+               excluded_names += ['fc', 'conv1']
+            if mx.current_context() == mx.gpu():
+               excluded_names += ['relu0', 'relu1']
             excluded_names += ['concat']
 
             optional_names = ['pool0']
