This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0da4b67 [MKLDNN]Add quantized relu (#14604)
0da4b67 is described below
commit 0da4b67ebf5788deef97ecaca5e30cbc9d27660d
Author: zhiyuan-huang <[email protected]>
AuthorDate: Thu Apr 18 20:18:45 2019 +0800
[MKLDNN]Add quantized relu (#14604)
* add quantized relu
* fix testcase
* add author and skip quantized-relu for gpu
* fix comments
* retrigger ci
* retrigger ci
* comment fix
* retrigger ci
* retrigger ci
---
src/operator/nn/mkldnn/mkldnn_act-inl.h | 74 +++++++++++
src/operator/nn/mkldnn/mkldnn_act.cc | 91 +++++---------
.../quantization/mkldnn/mkldnn_quantized_act.cc | 55 ++++++++
src/operator/quantization/quantize_graph_pass.cc | 29 +++--
src/operator/quantization/quantized_activation.cc | 138 +++++++++++++++++++++
tests/python/quantization/test_quantization.py | 55 +++++++-
6 files changed, 372 insertions(+), 70 deletions(-)
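
For context, here is a minimal sketch of how the new path is reached from user
code. It is not part of this patch: it assumes an MXNet build that includes this
commit plus the existing mxnet.contrib.quantization.quantize_model API, and the
toy network, layer names and shapes are illustrative only. Executing the
resulting int8 graph additionally requires an MKL-DNN build.

    # Sketch only: toy fp32 network with a relu that the quantize pass can rewrite.
    import json
    import mxnet as mx
    from mxnet.contrib.quantization import quantize_model

    data = mx.sym.Variable('data')
    conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=16, name='conv0')
    relu = mx.sym.Activation(conv, act_type='relu', name='relu0')
    sym = mx.sym.FullyConnected(relu, num_hidden=10, name='fc0')

    mod = mx.mod.Module(sym, label_names=None, context=mx.cpu())
    mod.bind(data_shapes=[('data', (1, 3, 32, 32))], for_training=False)
    mod.init_params()
    arg_params, aux_params = mod.get_params()

    # With this commit, Activation(act_type='relu') is rewritten to
    # _contrib_quantized_act by the quantize graph pass unless excluded.
    qsym, qarg_params, qaux_params = quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=mx.cpu(),
        excluded_sym_names=['fc0'], calib_mode='none', quantized_dtype='uint8')

    ops = [n['op'] for n in json.loads(qsym.tojson())['nodes']]
    print('_contrib_quantized_act' in ops)   # expect True
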
diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h
new file mode 100644
index 0000000..6bf30e3
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_act-inl.h
+ * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \author Zhiyuan Huang
+*/
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
+
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include <utility>
+#include "../activation-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
+mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
+ const ActivationParam& param, bool is_train,
+ const mkldnn::memory &input_mem, int dtype);
+
+class MKLDNNActForward {
+ public:
+ const mkldnn::eltwise_forward::primitive_desc fwd_pd;
+
+ MKLDNNActForward(const ActivationParam& param, bool is_train,
+ const NDArray &data, const mkldnn::memory &mem): fwd_pd(
+          GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
+ void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
+ const mkldnn::eltwise_forward &GetFwd() const;
+
+ private:
+ std::shared_ptr<mkldnn::eltwise_forward> fwd_;
+ std::shared_ptr<mkldnn::memory> data_;
+ std::shared_ptr<mkldnn::memory> out_;
+};
+
+typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
+MKLDNNActForward &GetActForward(const ActivationParam& param,
+ const OpContext &ctx, const NDArray &in_data,
+ const mkldnn::memory &in_mem);
+
+void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+ const NDArray &in_data, const OpReqType &req,
+ const NDArray &out_data);
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 8c64888..9ce27fa 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -32,8 +32,7 @@
#include <string>
#include <utility>
#include "../../operator_common.h"
-#include "../activation-inl.h"
-#include "./mkldnn_base-inl.h"
+#include "mkldnn_act-inl.h"
#if MXNET_USE_MKLDNN == 1
@@ -58,7 +57,7 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
return SupportMKLDNNAct(param);
}
-static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
+mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
return mkldnn::algorithm::eltwise_relu;
@@ -74,9 +73,7 @@ static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
}
}
-typedef std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> mkldnn_act_pdesc_ptr;
-
-static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
+mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
const ActivationParam& param, bool is_train,
const mkldnn::memory &input_mem, int dtype) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
@@ -84,65 +81,41 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
auto cpu_engine = data_mpd.get_engine();
auto alg = GetMKLDNNActAlgo(param);
- MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- DType alpha = 0;
- mkldnn::eltwise_forward::desc desc = is_train
- ? mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_training,
- alg, data_md, alpha)
- : mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_scoring,
- alg, data_md, alpha);
- return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
- });
- LOG(FATAL) << "Unsupported data type for MKLDNN activation";
- mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
- mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
+
+ auto prop = is_train ? mkldnn::prop_kind::forward_training :
+ mkldnn::prop_kind::forward_scoring;
+ auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
}
-typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
-
-class MKLDNNActForward {
- std::shared_ptr<mkldnn::eltwise_forward> fwd;
- std::shared_ptr<mkldnn::memory> data;
- std::shared_ptr<mkldnn::memory> out;
-
- public:
- const mkldnn::eltwise_forward::primitive_desc fwd_pd;
-
- MKLDNNActForward(const ActivationParam& param, bool is_train,
- const NDArray &data, const mkldnn::memory &mem): fwd_pd(
- GetActFwdDescImpl(param, is_train, mem, data.dtype())) {
- }
-
- void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
- if (this->data == nullptr)
- this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- data.get_primitive_desc(), data.get_data_handle()));
- else
- this->data->set_data_handle(data.get_data_handle());
-
- CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
- if (this->out == nullptr)
- this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.dst_primitive_desc(), output.get_data_handle()));
- else
- this->out->set_data_handle(output.get_data_handle());
-
- if (this->fwd == nullptr) {
- this->fwd = std::shared_ptr<mkldnn::eltwise_forward>(
-      new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data),
- *this->out));
- }
+void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
+ if (this->data_ == nullptr)
+ this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
+ data.get_data_handle());
+ else
+ this->data_->set_data_handle(data.get_data_handle());
+
+ CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
+ if (this->out_ == nullptr)
+ this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
+ output.get_data_handle());
+ else
+ this->out_->set_data_handle(output.get_data_handle());
+
+ if (this->fwd_ == nullptr) {
+ this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
+        new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
+ *this->out_));
}
+}
- const mkldnn::eltwise_forward &GetFwd() const {
- return *fwd;
- }
-};
+const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
+ return *fwd_;
+}
-static MKLDNNActForward &GetActForward(const ActivationParam& param,
-                                       const OpContext &ctx, const NDArray &in_data,
- const mkldnn::memory &in_mem) {
+MKLDNNActForward &GetActForward(const ActivationParam& param,
+ const OpContext &ctx, const NDArray &in_data,
+ const mkldnn::memory &in_mem) {
#if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActForward, OpHash> fwds;
#else
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
new file mode 100644
index 0000000..bc69cb5
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_quantized_act.cc
+ * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \author Zhiyuan Huang
+*/
+#if MXNET_USE_MKLDNN == 1
+
+#include "../../nn/mkldnn/mkldnn_act-inl.h"
+#include "../quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+static void MKLDNNQuantizedActForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data) {
+ CHECK(in_data[0].dtype() == mshadow::kUint8 ||
+ in_data[0].dtype() == mshadow::kInt8)
+ << "_contrib_quantized_act op only supports uint8 and int8 as input "
+ "type";
+
+ MKLDNNActivationForward(attrs, ctx, in_data[0], req[0], out_data[0]);
+ out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
+ out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_act)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedActForward);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5bd9e8a..7ff2999 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -89,11 +89,12 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs,
return outputs;
}
-inline bool NeedQuantize(const NodePtr node,
-                         const std::unordered_set<std::string>& excluded_nodes) {
+inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set<std::string>& excluded_nodes) {
+ std::unordered_map<NodePtr, NodePtr> quantized_node;
   static auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
const auto& op = node->op();
+
if (op && quantized_op_map.count(op)) {
bool need = true;
if (excluded_nodes.count(node->attrs.name)) {
@@ -112,14 +113,24 @@ inline bool NeedQuantize(const NodePtr node,
});
}
}
- return need;
+
+ if (need) {
+ auto n_ptr = quantized_op_map[node->op()];
+ auto tmp_node = n_ptr(node->attrs);
+ if (tmp_node->op()) {
+ quantized_node[node] = tmp_node;
+ } else {
+ quantized_node[node] = nullptr;
+ }
+ } else {
+ quantized_node[node] = nullptr;
+ }
}
- return false;
+ return quantized_node[node];
}
Graph QuantizeGraph(Graph &&src) {
   static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  static const auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
   static const auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
   static const auto& avoid_quantize_input_map = Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
@@ -136,11 +147,9 @@ Graph QuantizeGraph(Graph &&src) {
NodePtr new_node = Node::Create();
     // If the currently visited node needs quantization, insert a quantize op node before the
     // current node and replace the current node with the quantized version in the new graph.
- if (NeedQuantize(node, excluded_nodes)) {
- auto fquantized_op = quantized_op_map[node->op()];
-      // If the currently visited node's op registered the FQuantizedOp property, new_node is a
- // quantizated version of a that op, such as quantized_conv2d.
- new_node = fquantized_op(node->attrs);
+ auto tmp_node = NeedQuantize(node, excluded_nodes);
+ if (tmp_node) {
+ new_node = tmp_node;
// add data into quantized op input
for (size_t i = 0; i < node->inputs.size(); ++i) {
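
With this change NeedQuantize hands back the quantized node directly and returns
nullptr for excluded nodes or for ops without an FQuantizedOp mapping, in which
case QuantizeGraph keeps the original fp32 node. A small sketch of what that
looks like from Python via excluded_sym_names (not part of this patch; a
parameter-free toy graph is enough because calib_mode='none' never executes the
network):

    import json
    import mxnet as mx
    from mxnet.contrib.quantization import quantize_model

    net = mx.sym.Activation(mx.sym.Variable('data'), act_type='relu', name='relu0')

    def ops_of(s):
        return [n['op'] for n in json.loads(s.tojson())['nodes']]

    # relu0 excluded -> NeedQuantize returns nullptr and the fp32 node is kept.
    kept, _, _ = quantize_model(sym=net, arg_params={}, aux_params={},
                                ctx=mx.cpu(), excluded_sym_names=['relu0'],
                                calib_mode='none')
    # relu0 not excluded -> rewritten to _contrib_quantized_act.
    rewritten, _, _ = quantize_model(sym=net, arg_params={}, aux_params={},
                                     ctx=mx.cpu(), excluded_sym_names=[],
                                     calib_mode='none')

    print('_contrib_quantized_act' in ops_of(kept))       # expect False
    print('_contrib_quantized_act' in ops_of(rewritten))  # expect True
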
diff --git a/src/operator/quantization/quantized_activation.cc b/src/operator/quantization/quantized_activation.cc
new file mode 100644
index 0000000..4ab74d0
--- /dev/null
+++ b/src/operator/quantization/quantized_activation.cc
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file quantized_activation.cc
+*/
+#include <mxnet/op_attr_types.h>
+#include "../nn/activation-inl.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+bool QuantizedActivationShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_shape,
+ std::vector<TShape> *out_shape) {
+ CHECK_EQ(in_shape->size(), 3U);
+ if (shape_is_none(in_shape->at(0))) return false;
+ SHAPE_ASSIGN_CHECK(*in_shape, 1, TShape{1});
+ SHAPE_ASSIGN_CHECK(*in_shape, 2, TShape{1});
+ out_shape->clear();
+ out_shape->push_back((*in_shape)[0]);
+ out_shape->push_back(TShape{1});
+ out_shape->push_back(TShape{1});
+ return true;
+}
+
+bool QuantizedActivationType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_type,
+ std::vector<int> *out_type) {
+ const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+ CHECK_EQ(in_type->size(), 3U);
+ CHECK_EQ(out_type->size(), 3U);
+ if (param.act_type == activation::kReLU) {
+ TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt8);
+ } else {
+ LOG(FATAL) << "_contrib_quantized_act only supports act_type=relu for now";
+ }
+ TYPE_ASSIGN_CHECK(*in_type, 1, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_type, 2, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
+ return true;
+}
+
+inline static bool QuantizedActivationStorageType(const nnvm::NodeAttrs &attrs,
+ const int dev_mask,
+ DispatchMode *dispatch_mode,
+ std::vector<int> *in_attrs,
+                                                  std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 3);
+
+ *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_MKLDNN == 1
+  const ActivationParam &param = nnvm::get<ActivationParam>(attrs.parsed);
+  if (dev_mask == mshadow::cpu::kDevMask && param.act_type == activation::kReLU) {
+ *dispatch_mode = DispatchMode::kFComputeEx;
+ }
+#else
+ CHECK_EQ(out_attrs->size(), 3);
+#endif
+ for (int& out_attr : *out_attrs)
+ out_attr = kDefaultStorage;
+ return true;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_act)
+.describe(R"code(Activation operator for input and output data type of int8.
+The input and output data comes with min and max thresholds for quantizing
+the float32 data into int8.
+
+.. Note::
+    This operator only supports forward propagation. DO NOT use it in training.
+ This operator only supports `relu`)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<ActivationParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data", "min_data", "max_data"};
+ })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"output", "min_output", "max_output"};
+ })
+.set_attr<nnvm::FInferType>("FInferType", QuantizedActivationType)
+.set_attr<mxnet::FInferShape>("FInferShape", QuantizedActivationShape)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedActivationStorageType)
+.set_attr<FNeedRequantize>("FNeedRequantize",
+ [](const NodeAttrs& attrs) {
+ const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+ CHECK(param.act_type == activation::kReLU)
+ << "_contrib_quantized_act only supports act_type=relu for now";
+ return false;
+ })
+.add_argument("data", "NDArray-or-Symbol", "Input data.")
+.add_argument("min_data", "NDArray-or-Symbol", "Minimum value of data.")
+.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.")
+.add_arguments(ActivationParam::__FIELDS__());
+
+NNVM_REGISTER_OP(Activation)
+.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+ ActivationParam param;
+ param.Init(attrs.dict);
+ nnvm::NodePtr node = nnvm::Node::Create();
+ if (param.act_type == activation::kReLU) {
+ node->attrs.op = Op::Get("_contrib_quantized_act");
+ node->attrs.name = "quantized_" + attrs.name;
+ } else {
+ node->attrs.op = nullptr;
+ node->attrs.name = attrs.name;
+ }
+ node->attrs.dict = attrs.dict;
+ if (node->op()->attr_parser != nullptr) {
+ node->op()->attr_parser(&(node->attrs));
+ }
+ return node;
+});
+
+} // namespace op
+} // namespace mxnet
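
The registered op can also be called directly: it takes data, min_data and
max_data and returns output, min_output and max_output, and the MKL-DNN kernel
above forwards the min/max range unchanged. A sketch (not part of this patch;
it needs an MKL-DNN build, since no plain-CPU or GPU kernel is registered for
_contrib_quantized_act in this commit, and the values are illustrative):

    import mxnet as mx

    data = mx.nd.array([[0, 1, 7, 23, 120]]).astype('uint8')
    min_data = mx.nd.array([0.0])
    max_data = mx.nd.array([120.0])

    out, min_out, max_out = mx.nd.contrib.quantized_act(
        data=data, min_data=min_data, max_data=max_data, act_type='relu')

    # relu is an identity on unsigned data; the point here is the 3-in/3-out
    # signature and that the calibration range passes through untouched.
    print(out.asnumpy())
    print(min_out.asscalar(), max_out.asscalar())   # 0.0 120.0
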
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 757df81..2761e77 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -414,6 +414,57 @@ def test_quantized_flatten():
check_quantized_flatten((10, 15, 18), qdtype)
check_quantized_flatten((3, 4, 23, 23), qdtype)
+@with_seed()
+def test_quantized_act():
+ def check_quantized_act(data_shape, qdtype):
+ if is_test_for_native_cpu():
+            print('skipped testing quantized_act for native cpu since it is not supported yet')
+ return
+ elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_act for mkldnn cpu int8 since it is not supported yet')
+ return
+ elif is_test_for_gpu():
+            print('skipped testing quantized_act for gpu since it is not supported yet')
+ return
+ data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+ act_fp32 = mx.sym.Activation(data=data, act_type='relu', name='relu')
+ arg_shapes, _, _ = act_fp32.infer_shape(data=data_shape)
+ arg_names = act_fp32.list_arguments()
+        act_fp32_exe = act_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
+ if qdtype == 'uint8':
+ data_low = 0.0
+ data_high = 127.0
+ else:
+ data_low = -127.0
+ data_high = 127.0
+
+        act_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low,
+                                                                      high=data_high, shape=data_shape).astype(qdtype)
+ output = act_fp32_exe.forward()[0]
+
+ qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
+ min_data = mx.sym.Variable(name='min_data')
+ max_data = mx.sym.Variable(name='max_data')
+        quantized_act = mx.sym.contrib.quantized_act(data=qdata, min_data=min_data, max_data=max_data, act_type='relu')
+        act_int8_exe = quantized_act.simple_bind(ctx=mx.current_context(), grad_req='null')
+ qarg_names = quantized_act.list_arguments()
+
+        act_int8_exe.arg_dict[qarg_names[0]][:] = act_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
+        quantized_range_min = mx.nd.min(act_int8_exe.arg_dict[qarg_names[0]][:])
+        quantized_range_max = mx.nd.max(act_int8_exe.arg_dict[qarg_names[0]][:])
+        act_int8_exe.arg_dict[qarg_names[1]][:] = quantized_range_min.astype(qdtype)
+        act_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range_max.astype(qdtype)
+ qoutput, min_range, max_range = act_int8_exe.forward()
+
+ assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
+        assert_almost_equal(min_range.asscalar(), quantized_range_min.asscalar())
+        assert_almost_equal(max_range.asscalar(), quantized_range_max.asscalar())
+
+ for qdtype in ['int8', 'uint8']:
+ check_quantized_act((10,), qdtype)
+ check_quantized_act((10, 15), qdtype)
+ check_quantized_act((10, 15, 18), qdtype)
+ check_quantized_act((3, 4, 23, 23), qdtype)
@with_seed()
def test_quantize_params():
@@ -634,7 +685,9 @@ def test_quantize_model_with_forward():
arg_params, aux_params = mod.get_params()
excluded_names = []
if mx.current_context() == mx.cpu():
- excluded_names += ['fc']
+ excluded_names += ['fc', 'conv1']
+ if mx.current_context() == mx.gpu():
+ excluded_names += ['relu0', 'relu1']
excluded_names += ['concat']
optional_names = ['pool0']