This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new bb919783b5 [master] Unified oneDNN implementation calls for quantized operators (#20987)
bb919783b5 is described below
commit bb919783b5d1f7023284b5513a2615a085b76575
Author: PiotrWolinski - Intel <[email protected]>
AuthorDate: Wed Apr 27 13:14:17 2022 +0200
[master] Unified oneDNN implementation calls for quantized operators (#20987)
* Done unifying for requantize, quantize and quantized_fully_connected
* Unified oneDNN implementation calls for FComputeEx
* Done unifying for requantize, quantize and quantized_fully_connected
* Unified oneDNN implementation calls for FComputeEx
* Changed dnnl_quantize operator to avoid duplicated reorder
* Added linting
* Added SupportDNNLQuantize function
---
src/operator/nn/dnnl/dnnl_base-inl.h | 4 ++++
src/operator/quantization/dnnl/dnnl_quantize-inl.h | 3 ---
src/operator/quantization/quantize.cc | 20 +++++++++++++++++++-
.../quantization/quantized_fully_connected.cc | 8 +++++---
src/operator/quantization/requantize.cc | 22 +++++++++++++++++++++-
5 files changed, 49 insertions(+), 8 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 8e3d4835d1..c38895bb1e 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -146,6 +146,10 @@ static inline bool SupportDNNL(const NDArray& input) {
return SupportDNNL(input.dtype(), input.shape()) &&
SupportStorageDNNL(input.storage_type());
}
+static inline bool SupportDNNLQuantize(const int out_type) {
+ return out_type == mshadow::kUint8 || out_type == mshadow::kInt8;
+}
+
static inline bool DNNLEnvSet() {
static bool is_dnnl_enabled = dmlc::GetEnv("MXNET_ONEDNN_ENABLED", true);
return is_dnnl_enabled;
diff --git a/src/operator/quantization/dnnl/dnnl_quantize-inl.h b/src/operator/quantization/dnnl/dnnl_quantize-inl.h
index 56fa3152d1..1020603da5 100644
--- a/src/operator/quantization/dnnl/dnnl_quantize-inl.h
+++ b/src/operator/quantization/dnnl/dnnl_quantize-inl.h
@@ -67,9 +67,6 @@ static void DNNLQuantizeComputeKer(const std::vector<NDArray>& inputs,
attr.set_output_scales(mask, scales);
dnnl::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
NDArray in_buffer = inputs[0];
- if (inputs[0].IsView() && inputs[0].IsDNNLData())
- in_buffer = inputs[0].Reorder2Default();
-
auto i_mem = in_buffer.GetDNNLData();
auto i_desc = i_mem->get_desc();
size_t i_ndim = in_buffer.shape().ndim();
diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc
index 2b42abdca8..f918dcb4e5 100644
--- a/src/operator/quantization/quantize.cc
+++ b/src/operator/quantization/quantize.cc
@@ -47,6 +47,24 @@ bool QuantizeStorageType(const nnvm::NodeAttrs& attrs,
return true;
}
+#if MXNET_USE_ONEDNN == 1
+static void QuantizeComputeExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
+
+ if (SupportDNNLQuantize(param.out_type)) {
+ DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ DNNLRun(DNNLQuantizeCompute, attrs, ctx, inputs, req, outputs);
+ DNNL_OPCHECK_RUN(QuantizeCompute<cpu>, attrs, ctx, inputs, req, outputs);
+ return;
+ }
+ FallBackCompute(QuantizeCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
NNVM_REGISTER_OP(_contrib_quantize)
.add_alias("_npx_contrib_quantize")
.describe(R"code(Quantize a input tensor from float to `out_type`,
@@ -88,7 +106,7 @@ where
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsDNNL", true)
- .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLQuantizeCompute)
+ .set_attr<FComputeEx>("FComputeEx<cpu>", QuantizeComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", QuantizeCompute<cpu>)
    .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`")
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index 930816abef..5d4bdfd8e0 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -305,10 +305,12 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
#if MXNET_USE_ONEDNN == 1
void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
- const std::vector<NDArray>& in_data,
+ const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
-                                         const std::vector<NDArray>& out_data) {
-  DNNLQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data);
+                                         const std::vector<NDArray>& outputs) {
+  DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+  DNNLRun(DNNLQuantizedFullyConnectedForward, attrs, ctx, inputs, req, outputs);
+  DNNL_OPCHECK_RUN(QuantizedFullyConnectedForwardCPU, attrs, ctx, inputs, req, outputs);
}
#endif
diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc
index c16c3778ac..0faf2e7a4b 100644
--- a/src/operator/quantization/requantize.cc
+++ b/src/operator/quantization/requantize.cc
@@ -29,6 +29,26 @@
namespace mxnet {
namespace op {
+
+#if MXNET_USE_ONEDNN == 1
+void RequantizeForwardExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ const RequantizeParam& param = nnvm::get<RequantizeParam>(attrs.parsed);
+ auto out_type = GetQuantizeOutputType(param);
+
+ if (SupportDNNLQuantize(out_type)) {
+ DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ DNNLRun(DNNLRequantizeForward, attrs, ctx, inputs, req, outputs);
+ DNNL_OPCHECK_RUN(RequantizeForward<cpu>, attrs, ctx, inputs, req, outputs);
+ return;
+ }
+ FallBackCompute(RequantizeForward<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
DMLC_REGISTER_PARAMETER(RequantizeParam);
bool RequantizeStorageType(const nnvm::NodeAttrs& attrs,
@@ -74,7 +94,7 @@ inference accuracy.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsDNNL", true)
- .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLRequantizeForward)
+ .set_attr<FComputeEx>("FComputeEx<cpu>", RequantizeForwardExCPU)
#else
.set_attr<FCompute>("FCompute<cpu>", RequantizeForward<cpu>)
#endif