This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 5abdc77f3c [FEATURE] Add _npi_power_scalar and _npi_multiply_scalar fuse (#20976)
5abdc77f3c is described below
commit 5abdc77f3c191bc771a5e97a9cecf30fd832de96
Author: bartekkuncer <[email protected]>
AuthorDate: Tue Jul 5 09:39:50 2022 +0200
[FEATURE] Add _npi_power_scalar and _npi_multiply_scalar fuse (#20976)
* [FEATURE] Add _npi_power_scalar and _npi_multiply_scalar fuse
* Merge _npi_power_scalar implementation with implementation of this fuse
* Fix clang
* Fix CI
* Fix review and simplify the implementation
* Add checks for the number of inputs and outputs
* Fix CI
* Add Reset() function
* Fix DNNLPowMulScalarShape and Type functions
* Fix DNNLPowMulScalarType
* Fix DNNLPowMulScalarType
* Add generic implementation for sq_pow_mul_scalar operator
* Fix sanity
* Fix req
* Add Filter method to property
* Add new line
* Fix gpu CI
* Add '_sg_pow_mul_scalar' to symbol_fp16.py
* Fix CI on MacOS
* Fix SupportDNNL*
* Make PowMulScalarCompute more readable
* Fix PowMulScalarCompute
* Fix memory usage
* Fix build
---
include/mxnet/op_attr_types.h | 2 +-
include/mxnet/tensor_blob.h | 4 +-
python/mxnet/amp/lists/symbol_fp16.py | 3 +-
src/operator/nn/dnnl/dnnl_base-inl.h | 1 -
src/operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h | 100 ++++++++++
...dnnl_power_scalar.cc => dnnl_pow_mul_scalar.cc} | 65 +++----
src/operator/nn/dnnl/dnnl_power_scalar-inl.h | 66 -------
src/operator/subgraph/common.h | 2 +
src/operator/subgraph/dnnl/dnnl_bn_relu_property.h | 2 -
src/operator/subgraph/dnnl/dnnl_conv_property.h | 4 +-
src/operator/subgraph/dnnl/dnnl_fc_property.h | 8 -
src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h | 7 -
.../subgraph/dnnl/dnnl_post_quantize_property.h | 26 +--
src/operator/subgraph/dnnl/dnnl_pow_mul_scalar.cc | 208 +++++++++++++++++++++
.../subgraph/dnnl/dnnl_pow_mul_scalar_property.h | 126 +++++++++++++
.../subgraph/dnnl/dnnl_subgraph_property.cc | 2 +
.../subgraph/dnnl/dnnl_transformer_qk_property.h | 4 +-
src/operator/subgraph/subgraph_property.h | 2 +-
src/operator/tensor/elemwise_binary_scalar_op.h | 3 -
.../tensor/elemwise_binary_scalar_op_extended.cc | 11 +-
tests/python/dnnl/subgraphs/subgraph_common.py | 10 +-
tests/python/dnnl/subgraphs/test_fc_subgraph.py | 2 +-
.../python/dnnl/subgraphs/test_pow_mul_subgraph.py | 41 ++++
23 files changed, 547 insertions(+), 152 deletions(-)
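For reference, the pattern this commit targets is a scalar power followed by a scalar multiplication. A minimal, hypothetical sketch of a network that lowers to _npi_power_scalar + _npi_multiply_scalar and can therefore be rewritten by the new pass (names and shapes are illustrative only, not taken from the diff):

    import mxnet as mx
    from mxnet.gluon import nn

    mx.npx.set_np()

    class PowMul(nn.HybridBlock):
        def forward(self, x):
            # (x ** 2) * 3 lowers to _npi_power_scalar followed by _npi_multiply_scalar;
            # on an ONEDNN build the new property may fuse them into one _sg_pow_mul_scalar node.
            return (x ** 2) * 3

    net = PowMul()
    net.initialize()
    net.hybridize()
    out = net(mx.np.ones((2, 3)))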
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 0bc2a8f62d..01a2ad36b8 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -58,7 +58,7 @@ enum OpReqType {
};
/*!
- * \brief All the possible information needed by Operator.Forward and Backward
+ * \brief All the possible information needed by Operator.
* This is the superset of RunContext.
* We use this data structure to bookkeep everything needed by Forward and Backward.
* \sa Resource
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 479b3cf3a2..dc265e4d7a 100644
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -210,7 +210,7 @@ class TBlob {
CHECK(Device::kDevMask == this->dev_mask())
<< "TBlob.get: device type do not match specified type";
CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
- << "TBlob.get_with_shape: data type do not match specified type."
+ << "TBlob.get_with_shape: data type do not match specified type. "
<< "Expected: " << mshadow::dtype_string(type_flag_) << " v.s. given "
<< mshadow::dtype_string(mshadow::DataType<DType>::kFlag);
    return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_), shape_.FlatTo2D(), stream);
@@ -248,7 +248,7 @@ class TBlob {
template <typename DType>
inline DType* dptr() const {
CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
- << "TBlob.get_with_shape: data type do not match specified type."
+ << "TBlob.get_with_shape: data type do not match specified type. "
<< "Expected: " << mshadow::dtype_string(type_flag_) << " v.s. given "
<< mshadow::dtype_string(mshadow::DataType<DType>::kFlag);
return static_cast<DType*>(dptr_);
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index ad1f0ad4b2..1cd5316361 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -636,7 +636,8 @@ if Features().is_enabled('ONEDNN'):
'_sg_onednn_fully_connected',
'_sg_onednn_selfatt_qk',
'_sg_onednn_selfatt_valatt',
- '_sg_onednn_batch_dot'
+ '_sg_onednn_batch_dot',
+ '_sg_pow_mul_scalar'
])
# Functions that have to be cast to FP32 only for
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index e8cf9c44b3..66c1dc2c99 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -254,7 +254,6 @@ bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input);
bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const std::vector<NDArray>& input);
-bool SupportDNNLPower(const NDArray& input);
bool SupportDNNLQuantizedAct(const ActivationParam& param);
bool SupportDNNLReshape(const NDArray& input);
bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const NDArray& output);
diff --git a/src/operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h b/src/operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h
new file mode 100644
index 0000000000..7f8cd5c5bd
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_POW_MUL_SCALAR_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_POW_MUL_SCALAR_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <vector>
+
+#include "operator/tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct DNNLPowMulScalarParam : public dmlc::Parameter<DNNLPowMulScalarParam> {
+ float exponent;
+ float multiplier;
+ bool exp_is_int;
+ bool mul_is_int;
+
+ DMLC_DECLARE_PARAMETER(DNNLPowMulScalarParam) {
+    DMLC_DECLARE_FIELD(exponent).describe("Exponent for power operation.").set_default(1);
+    DMLC_DECLARE_FIELD(multiplier).describe("Multiplier for multiply operation.").set_default(1);
+ DMLC_DECLARE_FIELD(exp_is_int)
+ .describe("Indicate whether exponent is int type.")
+ .set_default(true);
+ DMLC_DECLARE_FIELD(mul_is_int)
+ .describe("Indicate whether multiplier is int type.")
+ .set_default(true);
+ }
+
+ bool operator==(const DNNLPowMulScalarParam& other) const {
+    return this->exponent == other.exponent && this->multiplier == other.multiplier &&
+           this->exp_is_int == other.exp_is_int && this->mul_is_int == other.mul_is_int;
+ }
+};
+
+using eltwise_fwd_t = dnnl::eltwise_forward;
+using eltwise_fwd_pd_t = dnnl::eltwise_forward::primitive_desc;
+
+typedef ParamOpSign<DNNLPowMulScalarParam> DNNLPowMulScalarSignature;
+
+class DNNLPowMulScalarFwd {
+ public:
+ static DNNLPowMulScalarFwd& GetCached(const DNNLPowMulScalarParam& param,
+ const NDArray& input,
+ const NDArray& output);
+
+  DNNLPowMulScalarFwd(const DNNLPowMulScalarParam& param, const NDArray& input);
+
+  void Execute(const NDArray& input, const OpReqType& req, const NDArray& output);
+
+ private:
+ std::shared_ptr<eltwise_fwd_t> fwd;
+ std::shared_ptr<eltwise_fwd_pd_t> fwd_pd;
+};
+
+template <bool subgraph>
+inline void DNNLPowMulScalarForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ DNNLPowMulScalarParam param;
+ if (subgraph) {
+ param = nnvm::get<DNNLPowMulScalarParam>(attrs.parsed);
+ } else {
+ param.multiplier = 1;
+ param.exponent = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed).scalar;
+ }
+  DNNLPowMulScalarFwd& fwd = DNNLPowMulScalarFwd::GetCached(param, inputs[0], outputs[0]);
+ fwd.Execute(inputs[0], req[0], outputs[0]);
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_POW_MUL_SCALAR_INL_H_
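The forward class above maps the fused operator onto oneDNN's eltwise_pow algorithm, which evaluates alpha * x^beta; the multiplier is passed as alpha and the exponent as beta (see the primitive descriptor in dnnl_pow_mul_scalar.cc below). A small NumPy sketch of the delegated math, for illustration only:

    import numpy as np

    def pow_mul_scalar_reference(x, exponent=1.0, multiplier=1.0):
        # Reference semantics of the fused op: oneDNN eltwise_pow evaluates
        # alpha * x**beta with alpha = multiplier and beta = exponent.
        return multiplier * np.power(x, exponent)

    x = np.random.rand(2, 3).astype(np.float32)
    assert np.allclose(pow_mul_scalar_reference(x, 2.0, 3.0), 3.0 * x ** 2)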
diff --git a/src/operator/nn/dnnl/dnnl_power_scalar.cc b/src/operator/nn/dnnl/dnnl_pow_mul_scalar.cc
similarity index 51%
rename from src/operator/nn/dnnl/dnnl_power_scalar.cc
rename to src/operator/nn/dnnl/dnnl_pow_mul_scalar.cc
index 223ab21c1c..f78440e949 100644
--- a/src/operator/nn/dnnl/dnnl_power_scalar.cc
+++ b/src/operator/nn/dnnl/dnnl_pow_mul_scalar.cc
@@ -18,49 +18,54 @@
*/
/*!
- * \file dnnl_power_scalar.cc
- * \author: Adam Grabowski, [email protected]
+ * \file dnnl_pow_mul_scalar.cc
*/
#if MXNET_USE_ONEDNN == 1
-#include "dnnl_power_scalar-inl.h"
+#include "dnnl_pow_mul_scalar-inl.h"
namespace mxnet {
namespace op {
-DNNLPowerFwd& DNNLPowerFwd::GetPowerForward(const nnvm::NodeAttrs& attrs,
- const NDArray& input,
- const NDArray& output) {
-  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+DMLC_REGISTER_PARAMETER(DNNLPowMulScalarParam);
+
+DNNLPowMulScalarFwd& DNNLPowMulScalarFwd::GetCached(const DNNLPowMulScalarParam& param,
+ const NDArray& input,
+ const NDArray& output) {
#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLPowerSignature, DNNLPowerFwd, OpHash> fwds;
+  static thread_local std::unordered_map<DNNLPowMulScalarSignature, DNNLPowMulScalarFwd, OpHash>
+      fwds;
#else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLPowerSignature, DNNLPowerFwd, OpHash> fwds;
+  static MX_THREAD_LOCAL std::unordered_map<DNNLPowMulScalarSignature, DNNLPowMulScalarFwd, OpHash>
+      fwds;
#endif
- DNNLPowerSignature key;
- key.AddSign(static_cast<float>(param.scalar));
+ DNNLPowMulScalarSignature key(param);
key.AddSign(input);
key.AddSign(output);
auto it = fwds.find(key);
if (it == fwds.end()) {
- const DNNLPowerFwd fwd(input, static_cast<float>(param.scalar));
+ const DNNLPowMulScalarFwd fwd(param, input);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}
-DNNLPowerFwd::DNNLPowerFwd(const NDArray& input, const float exponent) {
+DNNLPowMulScalarFwd::DNNLPowMulScalarFwd(const DNNLPowMulScalarParam& param, const NDArray& input) {
auto src_desc = input.GetDNNLData()->get_desc();
- dnnl::eltwise_forward::desc fwd_desc(
-      dnnl::prop_kind::forward_scoring, dnnl::algorithm::eltwise_pow, src_desc, 1, exponent);
+ dnnl::eltwise_forward::desc fwd_desc(dnnl::prop_kind::forward_scoring,
+ dnnl::algorithm::eltwise_pow,
+ src_desc,
+ param.multiplier,
+ param.exponent);
  fwd_pd = std::make_shared<eltwise_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
fwd = std::make_shared<eltwise_fwd_t>(*fwd_pd);
}
-void DNNLPowerFwd::Execute(const NDArray& input, const OpReqType& req, const NDArray& output) {
- auto engine = mxnet::CpuEngine::Get()->get_engine();
+void DNNLPowMulScalarFwd::Execute(const NDArray& input,
+ const OpReqType& req,
+ const NDArray& output) {
auto src = input.GetDNNLData();
  dnnl_output_t out_mem = CreateDNNLMem(output, fwd_pd->dst_desc(), req, &input);
@@ -73,22 +78,18 @@ void DNNLPowerFwd::Execute(const NDArray& input, const OpReqType& req, const NDA
CommitOutput(output, out_mem);
DNNLStream::Get()->Submit();
}
-
-void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& input,
- const OpReqType& req,
- const NDArray& output) {
- DNNLPowerFwd& fwd = DNNLPowerFwd::GetPowerForward(attrs, input, output);
- fwd.Execute(input, req, output);
-}
-
-bool SupportDNNLPower(const NDArray& input) {
-  return input.shape().Size() != 0 && input.shape().ndim() > 0 && input.shape().ndim() <= 6 &&
-         input.dtype() == mshadow::kFloat32;
-}
-
} // namespace op
} // namespace mxnet
+namespace std {
+template <>
+struct hash<mxnet::op::DNNLPowMulScalarParam> {
+ size_t operator()(const mxnet::op::DNNLPowMulScalarParam& val) {
+ size_t ret = 0;
+ ret = dmlc::HashCombine(ret, val.exponent);
+ ret = dmlc::HashCombine(ret, val.multiplier);
+ return ret;
+ }
+};
+} // namespace std
#endif // MXNET_USE_ONEDNN == 1
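GetCached() memoizes one forward primitive per unique combination of parameters and input/output arrays in a thread-local map. A rough, hypothetical Python analogue of that caching scheme (the real key also hashes the DNNL memory descriptors via AddSign):

    _fwd_cache = {}

    def get_cached_forward(exponent, multiplier, shape, dtype):
        # Stand-in for DNNLPowMulScalarFwd::GetCached: build the primitive once
        # per signature and reuse it for identical later calls on this thread.
        key = (exponent, multiplier, tuple(shape), str(dtype))
        if key not in _fwd_cache:
            _fwd_cache[key] = ("compiled eltwise_pow primitive for", key)
        return _fwd_cache[key]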
diff --git a/src/operator/nn/dnnl/dnnl_power_scalar-inl.h b/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
deleted file mode 100644
index 5ece7ef832..0000000000
--- a/src/operator/nn/dnnl/dnnl_power_scalar-inl.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_power_scalar-inl.h
- * \author: Adam Grabowski, [email protected]
- */
-
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_POWER_SCALAR_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_POWER_SCALAR_INL_H_
-
-#if MXNET_USE_ONEDNN == 1
-
-#include "dnnl_base-inl.h"
-#include "operator/tensor/elemwise_binary_scalar_op.h"
-
-namespace mxnet {
-namespace op {
-
-using eltwise_fwd_t = dnnl::eltwise_forward;
-using eltwise_fwd_pd_t = dnnl::eltwise_forward::primitive_desc;
-
-class DNNLPowerFwd {
- public:
- static DNNLPowerFwd& GetPowerForward(const nnvm::NodeAttrs& attrs,
- const NDArray& input,
- const NDArray& outputs);
-
- DNNLPowerFwd(const NDArray& input, const float exponent);
-
-  void Execute(const NDArray& input, const OpReqType& req, const NDArray& output);
-
- private:
- std::shared_ptr<eltwise_fwd_t> fwd;
- std::shared_ptr<eltwise_fwd_pd_t> fwd_pd;
-};
-
-typedef OpSignature DNNLPowerSignature;
-
-void DNNLPowerForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const NDArray& input,
- const OpReqType& req,
- const NDArray& output);
-
-} // namespace op
-} // namespace mxnet
-
-#endif // MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_NN_DNNL_DNNL_POWER_SCALAR_INL_H_
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index c2781faaf2..477be4bb96 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -29,6 +29,8 @@
namespace mxnet {
namespace op {
+enum SelectStatus { kFail = 0, kStart, kSuccess };
+
inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
const nnvm::Symbol& sym = *attrs.subgraphs[0];
return sym.ListInputNames(nnvm::Symbol::kAll).size();
diff --git a/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h b/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
index d59b5e29f8..792236a50d 100644
--- a/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
@@ -35,8 +35,6 @@ namespace op {
class SgDNNLBNReLUSelector : public SubgraphSelector {
public:
- enum SelectStatus { kStart, kSuccess, kFail };
-
explicit SgDNNLBNReLUSelector(const bool disable_bn_relu)
: disable_bn_relu_(disable_bn_relu), status_(kStart) {}
diff --git a/src/operator/subgraph/dnnl/dnnl_conv_property.h b/src/operator/subgraph/dnnl/dnnl_conv_property.h
index d7f90e44c2..814e70cdf5 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_conv_property.h
@@ -37,7 +37,7 @@ namespace op {
class SgDNNLConvSelector : public SubgraphSelector {
public:
/*! \brief pattern match status_ */
- enum SelectStatus {
+ enum SelectStatusConv {
kFail = 0,
kStart,
kBN,
@@ -51,7 +51,7 @@ class SgDNNLConvSelector : public SubgraphSelector {
bool disable_conv_act_;
bool disable_conv_sum_;
bool quantize_;
- SelectStatus status_;
+ SelectStatusConv status_;
std::vector<const nnvm::Node*> matched_list_;
public:
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index 5b567ee652..fd5272ef5a 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -40,14 +40,6 @@ namespace mxnet {
namespace op {
class SgDNNLFCSelector : public SubgraphSelector {
- public:
- /* pattern match status */
- enum SelectStatus {
- kFail = 0,
- kStart,
- kSuccess,
- };
-
private:
bool disable_fc_eltwise_;
bool quantized_;
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
index 6cfce33621..c65711493c 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
@@ -54,13 +54,6 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
private:
- /*! \brief pattern match status */
- enum SelectStatus {
- kFail = 0,
- kStart,
- kSuccess,
- };
-
bool quantized_;
SelectStatus status_ = kFail;
std::vector<const BiDirectedNode*> matched_list_;
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 2c21db8ad6..14717592b0 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -55,7 +55,7 @@ bool SupportsRequantizeFusion(const Op* op) {
class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
private:
/*! \brief pattern match status */
- enum class SelectStatus {
+ enum class SelectStatusPostQuantize {
kFail = 0,
kStart,
kRequantize,
@@ -64,7 +64,7 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
bool fuse_all;
bool float_output;
- SelectStatus status;
+ SelectStatusPostQuantize status;
std::vector<const BiDirectedNode*> matched_list;
public:
@@ -75,7 +75,7 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
const nnvm::Node* raw_node = n.node;
    if (fuse_all && raw_node->op() && SupportsRequantizeFusion(raw_node->op())) {
- status = SelectStatus::kStart;
+ status = SelectStatusPostQuantize::kStart;
matched_list.clear();
matched_list.emplace_back(&n);
return true;
@@ -94,7 +94,7 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
static const std::set<const Op*> dequantize_fusion_unsupported_ops = {
Op::Get("_contrib_quantized_elemwise_add"),
Op::Get("_contrib_quantized_npi_add")};
- if (status == SelectStatus::kFail || status == SelectStatus::kSuccess ||
+    if (status == SelectStatusPostQuantize::kFail || status == SelectStatusPostQuantize::kSuccess ||
raw_new_node->is_variable())
return false;
// If n isn't the last matched node, then we encoutered a internal
@@ -105,26 +105,26 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
matched_list.pop_back();
}
}
- status = SelectStatus::kSuccess;
+ status = SelectStatusPostQuantize::kSuccess;
return false;
}
switch (status) {
- case SelectStatus::kStart:
+ case SelectStatusPostQuantize::kStart:
if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
          auto const& param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
          if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
matched_list.emplace_back(&new_node);
- status = SelectStatus::kRequantize;
+ status = SelectStatusPostQuantize::kRequantize;
            // For now there is no support for dequantize fusion for some operators
// so then we finish after finding requantize node:
if (dequantize_fusion_unsupported_ops.count(raw_node->op()) != 0) {
- status = SelectStatus::kSuccess;
+ status = SelectStatusPostQuantize::kSuccess;
}
return true;
}
}
- case SelectStatus::kRequantize:
+ case SelectStatusPostQuantize::kRequantize:
      if (float_output && raw_new_node->op() == Op::Get("_contrib_dequantize")) {
CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
if (n.outputs.size() > 1) {
@@ -133,23 +133,23 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
for (const auto& kv : n.outputs) {
const auto& node = kv.first;
if (node->op() != Op::Get("_contrib_dequantize")) {
- status = SelectStatus::kSuccess;
+ status = SelectStatusPostQuantize::kSuccess;
return false;
}
}
}
matched_list.emplace_back(&new_node);
- status = SelectStatus::kSuccess;
+ status = SelectStatusPostQuantize::kSuccess;
return true;
}
default:
- status = SelectStatus::kSuccess;
+ status = SelectStatusPostQuantize::kSuccess;
return false;
}
}
  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
- if (status != SelectStatus::kSuccess || (matched_list.size() <= 1)) {
+    if (status != SelectStatusPostQuantize::kSuccess || (matched_list.size() <= 1)) {
return std::vector<BiDirectedNode*>(0);
} else {
std::vector<BiDirectedNode*> ret;
diff --git a/src/operator/subgraph/dnnl/dnnl_pow_mul_scalar.cc b/src/operator/subgraph/dnnl/dnnl_pow_mul_scalar.cc
new file mode 100644
index 0000000000..09853852e1
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_pow_mul_scalar.cc
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar.cc
+ * \brief DNNL pow_mul_scalar operator based on subgraph
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h"
+#include "operator/subgraph/common.h"
+
+namespace mxnet {
+namespace op {
+bool DNNLPowMulScalarType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+  const DNNLPowMulScalarParam& param = nnvm::get<DNNLPowMulScalarParam>(attrs.parsed);
+ bool scalar_is_int = param.exp_is_int && param.mul_is_int;
+ if (common::is_int(in_attrs->at(0)) && !scalar_is_int) {
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64);
+ } else if (in_attrs->at(0) == mshadow::kBool) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, scalar_is_int ? mshadow::kInt64 : mshadow::kFloat64);
+ } else {
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+ }
+ return out_attrs->at(0) != -1;
+}
+
+inline static bool DNNLPowMulScalarStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+template <typename OP>
+static void ComputeOP(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ mshadow::Stream<cpu>* s,
+ const TBlob& input,
+ const TBlob& output,
+ const double scalar,
+ char* extra_mem_handle = nullptr,
+ const size_t extra_mem_offset = 0) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+ auto temp_req = input.dptr_ == output.dptr_ ? kWriteInplace : kWriteTo;
+ TBlob temp_tblob = input;
+ if (input.type_flag_ != output.type_flag_) {
+ auto shape = Shape1(output.Size());
+ temp_tblob = TBlob(mshadow::Tensor<cpu, 1, DType>(
+          reinterpret_cast<DType*>(extra_mem_handle + extra_mem_offset), shape, shape[0], s));
+ CastCompute<cpu>(attrs, ctx, {input}, {kWriteTo}, {temp_tblob});
+ }
+ MXNET_ASSIGN_REQ_SWITCH(temp_req, Req, {
+ mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
+          s, input.Size(), output.dptr<DType>(), temp_tblob.dptr<DType>(), DType(scalar));
+ });
+ });
+}
+
+static void PowMulScalarCompute(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+ DCHECK_EQ(inputs.size(), 1);
+ DCHECK_EQ(outputs.size(), 1);
+ using namespace mshadow;
+ using namespace mshadow::expr;
+  const DNNLPowMulScalarParam& param = nnvm::get<DNNLPowMulScalarParam>(attrs.parsed);
+  // temp_mid_tblob is output of power operation and input of multiplication. Its dtype depends on
+ // input dtype and scalar type.
+ TBlob temp_mid_tblob;
+ if (inputs[0].type_flag_ == outputs[0].type_flag_) {
+    // If dtype is the same for input and output data there is no need for additional memory.
+ temp_mid_tblob = outputs[0];
+    ComputeOP<mshadow_op::power>(attrs, ctx, s, inputs[0], temp_mid_tblob, param.exponent);
+    ComputeOP<mshadow_op::mul>(attrs, ctx, s, temp_mid_tblob, outputs[0], param.multiplier);
+ } else {
+    // If input dtype is different than the output dtype we can be sure that input is of integer
+    // or bool dtype and there will be some additional memory needed for temp_mid_tblob and Cast
+    // operations.
+ char* extra_mem_handle = nullptr;
+    size_t exp_mem_offset = 0;  // Memory offset for eventual Cast operation in power operation.
+    size_t mul_mem_offset = 0;  // Memory offset for eventual Cast operation in mul operation.
+ size_t extra_mem_size = 0;
+ const size_t tensor_size = inputs[0].Size();
+ const auto shape = Shape1(tensor_size);
+      // If exponent is not an integer output of both power and multiplication operations is of
+      // same dtype as outputs[0]. Extra memory space is needed only for Cast in power operation.
power operation.
+ temp_mid_tblob = outputs[0];
+ extra_mem_size = tensor_size * sizeof(double);
+      extra_mem_handle =
+          reinterpret_cast<char*>(ctx.requested[0].get_space_internal(extra_mem_size, "PowMul"));
+ } else {
+ if (!param.mul_is_int) {
+ // Extra memory space needed for Cast in mul operation.
+ extra_mem_size = tensor_size * sizeof(double);
+ }
+ if (inputs[0].type_flag_ == kBool) {
+        // If exponent is an integer and input data is of bool dtype, output of the power
+        // operation is of int64 dtype. Extra memory needed for temp_mid_tblob and Cast in power
+        // operation.
+ exp_mem_offset = tensor_size * sizeof(int64_t);
+ mul_mem_offset = tensor_size * sizeof(int64_t) * 2;
+ extra_mem_size += mul_mem_offset;
+        extra_mem_handle =
+            reinterpret_cast<char*>(ctx.requested[0].get_space_internal(extra_mem_size, "PowMul"));
+ temp_mid_tblob = TBlob(mshadow::Tensor<cpu, 1, int64_t>(
+ reinterpret_cast<int64_t*>(extra_mem_handle), shape, shape[0], s));
+ } else {
+        // If both exponent and input data is of integer dtype, output of the power operation is of
+ // the same dtype as its input. Extra memory needed for temp_mid_tblob.
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ mul_mem_offset = tensor_size * sizeof(DType);
+ extra_mem_size += mul_mem_offset;
+ extra_mem_handle = reinterpret_cast<char*>(
+ ctx.requested[0].get_space_internal(extra_mem_size, "PowMul"));
+ temp_mid_tblob = TBlob(mshadow::Tensor<cpu, 1, DType>(
+ reinterpret_cast<DType*>(extra_mem_handle), shape, shape[0], s));
+ });
+ }
+ }
+ ComputeOP<mshadow_op::power>(
+        attrs, ctx, s, inputs[0], temp_mid_tblob, param.exponent, extra_mem_handle, exp_mem_offset);
+ ComputeOP<mshadow_op::mul>(attrs,
+ ctx,
+ s,
+ temp_mid_tblob,
+ outputs[0],
+ param.multiplier,
+ extra_mem_handle,
+ mul_mem_offset);
+ }
+}
+
+static void PowMulScalarComputeEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<mxnet::NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<mxnet::NDArray>& outputs) {
+ if (SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[0])) {
+ DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ DNNLRun(DNNLPowMulScalarForward<true>, attrs, ctx, inputs, req, outputs);
+ DNNL_OPCHECK_RUN(PowMulScalarCompute, attrs, ctx, inputs, req, outputs);
+ } else {
+ FallBackCompute(PowMulScalarCompute, attrs, ctx, inputs, req, outputs);
+ }
+}
+
+NNVM_REGISTER_OP(_sg_pow_mul_scalar)
+ .describe(R"code(_sg_pow_mul_scalar)code" ADD_FILELINE)
+ .set_num_inputs([](const NodeAttrs& attrs) { return 1; })
+ .set_num_outputs([](const NodeAttrs& attrs) { return 1; })
+ .set_attr_parser(ParamParser<DNNLPowMulScalarParam>)
+    .set_attr<nnvm::FListInputNames>("FListInputNames",
+                                     [](const NodeAttrs& attrs) {
+                                       return std::vector<std::string>{"input"};
+                                     })
+    .set_attr<nnvm::FListOutputNames>("FListOutputNames",
+                                      [](const NodeAttrs& attrs) {
+                                        return std::vector<std::string>{"output"};
+                                      })
+ .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+ .set_attr<nnvm::FInferType>("FInferType", DNNLPowMulScalarType)
+    .set_attr<FInferStorageType>("FInferStorageType", DNNLPowMulScalarStorageType)
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+ .set_attr<FCompute>("FCompute<cpu>", PowMulScalarCompute)
+ .set_attr<FComputeEx>("FComputeEx<cpu>", PowMulScalarComputeEx)
+ .set_attr<bool>("TIsDNNL", true)
+ .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_ONEDNN == 1
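DNNLPowMulScalarType above mirrors NumPy-style promotion: an integer input with a non-integer scalar promotes to float64, a bool input promotes to int64 (when both scalars are integral) or float64, and any other input keeps its dtype. A hedged NumPy illustration of the intended rule (not the MXNet code itself):

    import numpy as np

    def expected_out_dtype(in_dtype, exp_is_int, mul_is_int):
        # Illustrative restatement of DNNLPowMulScalarType's promotion rule.
        scalar_is_int = exp_is_int and mul_is_int
        if np.issubdtype(in_dtype, np.integer) and not scalar_is_int:
            return np.float64
        if in_dtype == np.bool_:
            return np.int64 if scalar_is_int else np.float64
        return in_dtype

    assert expected_out_dtype(np.int32, True, False) == np.float64
    assert expected_out_dtype(np.bool_, True, True) == np.int64
    assert expected_out_dtype(np.float32, False, False) == np.float32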
diff --git a/src/operator/subgraph/dnnl/dnnl_pow_mul_scalar_property.h b/src/operator/subgraph/dnnl/dnnl_pow_mul_scalar_property.h
new file mode 100644
index 0000000000..cc310ab27e
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_pow_mul_scalar_property.h
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar_property.h
+ * \brief Graph property for fusing _npi_power_scalar with _npi_multiply_scalar
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "operator/subgraph/common.h"
+#include "operator/tensor/elemwise_binary_scalar_op.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgDNNLPowMulScalarSelector : public SubgraphSelectorV2 {
+ private:
+ bool patternFound = false;
+
+ public:
+ bool Select(const BiDirectedNode& seed_node,
+ const std::shared_ptr<NodeAttr>& node_attr) override {
+ if (seed_node.node->op() == Op::Get("_npi_power_scalar") &&
+ seed_node.node->num_outputs() == 1) {
+ return true;
+ }
+ return false;
+ }
+
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+ return false;
+ }
+
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
+    if (!patternFound && output_node.node->op() == Op::Get("_npi_multiply_scalar") &&
+ output_node.node->num_inputs() == 1) {
+ patternFound = true;
+ return true;
+ }
+ return false;
+ }
+
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
+ return patternFound ? candidates : std::vector<BiDirectedNode*>(0);
+ }
+
+ void Reset() override {
+ patternFound = false;
+ }
+};
+
+class SgDNNLPowMulScalarProperty : public SubgraphProperty {
+ public:
+ SgDNNLPowMulScalarProperty() {}
+
+ static SubgraphPropertyPtr Create() {
+ static const std::string& name = "DNNL PowMulScalar optimization pass";
+    auto property = std::make_shared<SgDNNLPowMulScalarProperty>();
+ property->SetAttr<std::string>("property_name", name);
+ property->SetAttr<bool>("inference_only", true);
+ if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_POW_MUL_SCALAR_OPT", 0)) {
+ property->SetAttr<bool>("disable", true);
+ }
+ return property;
+ }
+
+ nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override {
+ nnvm::ObjectPtr n = nnvm::Node::Create();
+
+ DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
+ if (node->is_variable())
+ return;
+ if (node->op() == Op::Get("_npi_power_scalar")) {
+        auto params = nnvm::get<NumpyBinaryScalarParam>(node->attrs.parsed);
+ n->attrs.dict["exponent"] = std::to_string(params.scalar);
+ n->attrs.dict["exp_is_int"] = std::to_string(params.is_int);
+ } else if (node->op() == Op::Get("_npi_multiply_scalar")) {
+        auto params = nnvm::get<NumpyBinaryScalarParam>(node->attrs.parsed);
+ n->attrs.dict["multiplier"] = std::to_string(params.scalar);
+ n->attrs.dict["mul_is_int"] = std::to_string(params.is_int);
+ }
+ });
+
+ n->attrs.name = "_sg_pow_mul_scalar" + std::to_string(subgraph_id);
+ n->attrs.op = Op::Get("_sg_pow_mul_scalar");
+ CHECK(n->attrs.op);
+ n->op()->attr_parser(&(n->attrs));
+ return n;
+ }
+
+ SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+ auto selector = std::make_shared<SgDNNLPowMulScalarSelector>();
+ return selector;
+ }
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
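As registered in Create() above, the pass can be switched off through the MXNET_DISABLE_ONEDNN_POW_MUL_SCALAR_OPT environment variable. A hedged usage sketch (setting the variable before MXNet is imported is the safest way to make sure the property sees it):

    import os

    # Disable only the pow*mul scalar fusion while keeping the rest of the
    # ONEDNN subgraph passes enabled.
    os.environ['MXNET_DISABLE_ONEDNN_POW_MUL_SCALAR_OPT'] = '1'

    import mxnet as mx  # the ONEDNN properties read the flag when they are created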
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index a7a290f93f..69fae1c97d 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -27,6 +27,7 @@
#include "dnnl_post_amp_property.h"
#include "dnnl_post_quantize_align_scale_property.h"
#include "dnnl_post_quantize_property.h"
+#include "dnnl_pow_mul_scalar_property.h"
#include "dnnl_transformer_qk_property.h"
#include "dnnl_transformer_valatt_property.h"
#include "dnnl_fc_sum_fuse.h"
@@ -47,6 +48,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLRemoveCastsProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLPowMulScalarProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCSumFuseProperty);
MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CPU());
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
index fc14df37e2..a70be63283 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -52,7 +52,7 @@ namespace mxnet {
namespace op {
class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
- enum SelectStatus {
+ enum SelectStatusTransformerQK {
kFail = 0,
kStart,
kFirstSwapAx,
@@ -68,7 +68,7 @@ class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
*/
private:
- SelectStatus status_;
+ SelectStatusTransformerQK status_;
std::vector<const BiDirectedNode*> matched_list_;
bool CheckSplitConditions(const BiDirectedNode& node) {
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
index 007ac421eb..563516d51e 100644
--- a/src/operator/subgraph/subgraph_property.h
+++ b/src/operator/subgraph/subgraph_property.h
@@ -192,7 +192,7 @@ class SubgraphSelectorV2 {
/*!
* \brief Post processes pre-selected subgraph nodes. Return a list of nodes that
* users want to keep in subgraph(s).
- * \param candidates re-selected subgraph nodes to filt
+ * \param candidates re-selected subgraph nodes to filter
* \return a list of nodes to keep
*/
  virtual std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) {
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index f18cc16c44..85e8da2c67 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -33,9 +33,6 @@
#include "../elemwise_op_common.h"
#include "../../common/alm.h"
#include "elemwise_unary_op.h"
-#if MXNET_USE_ONEDNN == 1
-#include "operator/nn/dnnl/dnnl_power_scalar-inl.h"
-#endif
namespace mxnet {
namespace op {
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
index 564951138a..e6273cd9d0 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
@@ -21,9 +21,10 @@
* \file elemwise_binary_scalar_op_extended.cc
* \brief CPU Implementation of extended binary scalar functions.
*/
-#include "./elemwise_unary_op.h"
-#include "./elemwise_binary_op.h"
-#include "./elemwise_binary_scalar_op.h"
+#include "elemwise_unary_op.h"
+#include "elemwise_binary_op.h"
+#include "elemwise_binary_scalar_op.h"
+#include "operator/nn/dnnl/dnnl_pow_mul_scalar-inl.h"
namespace mxnet {
namespace op {
@@ -66,9 +67,9 @@ void PowerComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<mxnet::NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<mxnet::NDArray>& outputs) {
- if (SupportDNNLPower(inputs[0])) {
+ if (SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- DNNLRun(DNNLPowerForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+ DNNLRun(DNNLPowMulScalarForward<false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(
        (BinaryScalarOp::Compute<cpu, mshadow_op::power>), attrs, ctx, inputs, req, outputs);
} else {
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index c1224d25e3..a23ba3b69c 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -219,27 +219,27 @@ def check_quantize(net_original, data_shapes, out_type,
name='conv',
rtol=0.1, atol=atol, etol=0.2)
-def check_fusion(net_original, data_shapes, attrs_dict, check_fp32_fusion=True,
+def check_fusion(net_original, data_shapes, attrs_dict, input_type='float32', check_fusion=True,
check_quantization=True, out_types=['uint8', 'int8', 'auto'],
dedup_subgraph=True,
quantize_mode='full'):
net_original.initialize()
net_original.hybridize(static_alloc=False, static_shape=False)
one_shape = isinstance(data_shapes, tuple)
- data_min = -1.0
- data_max = 1.0
+ data_min = -1.0 if input_type == 'float32' else -10
+ data_max = 1.0 if input_type == 'float32' else 10
if one_shape:
    # replace one shape with list of shapes with one element to follow later the same schema
data_shapes=[data_shapes]
data = []
for shape in data_shapes:
-    data.append(mx.np.random.uniform(size=shape, dtype='float32', device=mx.cpu(),
+    data.append(mx.np.random.uniform(size=shape, dtype=input_type, device=mx.cpu(),
low=data_min, high=data_max))
net_original(*data)
net_fusion = copy.copy(net_original)
sym, _ = net_original.export(None)
- if check_fp32_fusion:
+ if check_fusion:
if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
check_quantization = False
data_min = 0
diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
index e7ae08682a..de680ae352 100644
--- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -320,7 +320,7 @@ def function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, flatten, relu
data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
check_fusion(net, data_shapes, attrs,
out_types=[out_type],
- check_fp32_fusion=(quantize_mode is None),
+ check_fusion=(quantize_mode is None),
check_quantization=(quantize_mode is not None) and flatten,
quantize_mode=quantize_mode)
diff --git a/tests/python/dnnl/subgraphs/test_pow_mul_subgraph.py b/tests/python/dnnl/subgraphs/test_pow_mul_subgraph.py
new file mode 100644
index 0000000000..eea90cc441
--- /dev/null
+++ b/tests/python/dnnl/subgraphs/test_pow_mul_subgraph.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import pytest
+from subgraph_common import check_fusion
+from subgraph_common import DATA_SHAPE
+from mxnet.gluon import nn
+
[email protected]_np
[email protected]('data_shape', DATA_SHAPE)
[email protected]('input_type', ['float32', 'int32', 'int8'])
[email protected]('exponent', [2, 2.0])
[email protected]('multiplier', [3, 3.0])
+def test_pow_mul_fuse(data_shape, input_type, exponent, multiplier):
+ class TestPowMulFuse(nn.HybridBlock):
+ def __init__(self):
+ super(TestPowMulFuse, self).__init__()
+
+ def forward(self, input, *args):
+ return (input**exponent)*multiplier
+
+ net = TestPowMulFuse()
+ attrs = {'sg_pow_mul_scalar' : []}
+ check_fusion(net, data_shape, attrs,
+ input_type=input_type,
+ check_quantization=False)
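Beyond the check_fusion helper used above, a quick manual way to see the new node is to partition a matching block for the ONEDNN backend and list the resulting symbol's outputs; a hedged sketch (optimize_for is the standard Gluon entry point for subgraph backends, and the export(None) pattern mirrors subgraph_common.py):

    import mxnet as mx
    from mxnet.gluon import nn

    mx.npx.set_np()

    class PowMul(nn.HybridBlock):
        def forward(self, x):
            return (x ** 2) * 3

    net = PowMul()
    net.initialize()
    net.hybridize()
    x = mx.np.ones((4, 4))
    net.optimize_for(x, backend='ONEDNN')
    sym, _ = net.export(None)
    # On an ONEDNN-enabled build the fused node is expected to show up here.
    print([name for name in sym.get_internals().list_outputs() if 'pow_mul_scalar' in name])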