This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e9840b8 [FEATURE] Fuses FC + elemwise_add operators for oneDNN (#20821)
e9840b8 is described below
commit e9840b89f3305c446e4bab0402232abc68bde6dc
Author: Andrzej Kotłowski <[email protected]>
AuthorDate: Fri Jan 28 14:57:28 2022 +0100
[FEATURE] Fuses FC + elemwise_add operators for oneDNN (#20821)
* Fix elemwise_add post quantization pass
* Align naming convention with convolution operator
Convolution uses the data_name_[min|max] convention, which is object
oriented and more readable.
* Fuse FC with elemwise_add
* [TEST] Add functional tests for FC + add operator fusion
* [TESTS] Disable quantization check for unsupported cases
* Take into account already fused elemwise operation
A fix is added for fusing an already fused FC + relu/activation in the
floating-point path.
Fusing elemwise_add with an FC that already has a fused relu/activation is
blocked due to accuracy issues.
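
Illustration (not part of this commit): a minimal Gluon sketch of the pattern this
change fuses, modeled on the tests added below. The FCAdd block, tensor shapes, and
the 'ONEDNN' backend name are assumptions for the example; the tests themselves use
SG_PASS_NAME and their own blocks.

import mxnet as mx
from mxnet.gluon import nn

@mx.util.use_np
def fc_add_fusion_sketch():
    class FCAdd(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(FCAdd, self).__init__(**kwargs)
            self.fc = nn.Dense(units=16, flatten=True)

        def forward(self, data, residual):
            # FullyConnected output added element-wise to a second input
            return self.fc(data) + residual

    net = FCAdd()
    net.initialize()
    net.hybridize()
    data = mx.np.random.uniform(size=(4, 32))
    residual = mx.np.random.uniform(size=(4, 16))
    net(data, residual)
    # The ONEDNN subgraph pass should collapse Dense + add into a single
    # _sg_onednn_fully_connected node carrying with_sum=true.
    net.optimize_for(data, residual, backend='ONEDNN')
    sym, _ = net.export(None)
    return sym.attr_dict()
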
---
src/operator/nn/dnnl/dnnl_fully_connected-inl.h | 85 +++++
src/operator/nn/dnnl/dnnl_fully_connected.cc | 3 +
src/operator/subgraph/build_subgraph.cc | 11 +
src/operator/subgraph/dnnl/dnnl_fc.cc | 424 +++++++++++++--------
src/operator/subgraph/dnnl/dnnl_fc_property.h | 3 +
src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h | 291 ++++++++++++++
.../subgraph/dnnl/dnnl_post_quantize_property.h | 8 +-
.../subgraph/dnnl/dnnl_subgraph_property.cc | 4 +
tests/python/dnnl/subgraphs/subgraph_common.py | 154 ++++++--
tests/python/dnnl/subgraphs/test_conv_subgraph.py | 40 +-
tests/python/dnnl/subgraphs/test_fc_subgraph.py | 144 ++++++-
11 files changed, 936 insertions(+), 231 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index c30ad4b..4196b15 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -28,6 +28,8 @@
#if MXNET_USE_ONEDNN == 1
+#include <memory>
+#include <unordered_map>
#include <string>
#include <vector>
@@ -41,6 +43,8 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
bool quantized;
bool enable_float_output;
bool with_eltwise;
+ bool with_sum;
+ bool first_quantization_pass; // True for operator created during first
quantization pass
dmlc::optional<float> min_calib_range; // min float value calculated from
calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from
calibration dataset
dmlc::optional<bool> channel_wise_quantize;
@@ -54,6 +58,10 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
DMLC_DECLARE_FIELD(with_eltwise)
.set_default(false)
.describe("Whether there's a post with_eltwise after FullyConnected
operator");
+ DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
+ DMLC_DECLARE_FIELD(first_quantization_pass)
+ .set_default(false)
+ .describe("True for first quantization pass");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
@@ -76,9 +84,86 @@ struct DNNLFCFullParam {
FullyConnectedParam default_param;
DNNLFCParam dnnl_param;
DNNLPostEltwiseParam eltwise_param;
+ float sum_scale = {1.0f};
std::vector<float> output_scales = {0.0f};
};
+static inline size_t GetInSumIndex(const DNNLFCFullParam& param) {
+ assert(param.dnnl_param.with_sum);
+ return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);
+}
+
+class FCInputIndex {
+ public:
+ explicit FCInputIndex(const DNNLFCFullParam full_param) {
+ auto& dnnl_param = full_param.dnnl_param;
+ const bool has_bias = !full_param.default_param.no_bias;
+ const bool quantized = dnnl_param.quantized;
+ const bool sum_input_quantized =
+ quantized && dnnl_param.with_sum && !dnnl_param.enable_float_output;
+ const bool channel_wise = quantized &&
dnnl_param.channel_wise_quantize.has_value() &&
+ dnnl_param.channel_wise_quantize.value();
+
+ // Calculate position of particular input in the input vector:
+ int index = 0;
+ data = index++;
+ weight = index++;
+ bias = has_bias ? index++ : 0;
+ sum = dnnl_param.with_sum ? index++ : 0;
+ num_base = index; // note number of base inputs
+
+ data_min = quantized ? index++ : 0;
+ data_max = quantized ? index++ : 0;
+ weight_min = (quantized && !channel_wise) ? index++ : 0;
+ weight_max = (quantized && !channel_wise) ? index++ : 0;
+ bias_min = (quantized && !channel_wise && has_bias) ? index++ : 0;
+ bias_max = (quantized && !channel_wise && has_bias) ? index++ : 0;
+ sum_min = sum_input_quantized ? index++ : 0;
+ sum_max = sum_input_quantized ? index++ : 0;
+ num_total = index; // note number of total inputs
+ }
+
+ // Returns true if sum input exists
+ bool IsSumExist() const {
+ return sum;
+ }
+
+ // Returns true if bias input exists
+ bool IsBiasExist() const {
+ return bias;
+ }
+
+ // Returns true if sum input exists and it is a float number
+ bool IsSumInputFloat() const {
+ return (sum && !sum_min);
+ }
+ int GetTotal() const {
+ return num_total;
+ }
+ int GetBase() const {
+ return num_base;
+ }
+
+ // Represent index of particular input in the input vector:
+ int data;
+ int weight;
+ int bias;
+ int sum;
+ int data_min;
+ int data_max;
+ int weight_min;
+ int weight_max;
+ int bias_min;
+ int bias_max;
+ int sum_min;
+ int sum_max;
+
+ private:
+ int num_base; // Number of standard inputs
+ int num_total; // Number of total inputs: standard + additional needed for
+ // quantization
+};
+
dnnl::inner_product_forward::primitive_desc GetFCFwdImpl(const
DNNLFCFullParam& full_param,
const bool is_train,
const NDArray& data,
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected.cc
b/src/operator/nn/dnnl/dnnl_fully_connected.cc
index eca90b7..6f04b19 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected.cc
+++ b/src/operator/nn/dnnl/dnnl_fully_connected.cc
@@ -53,6 +53,9 @@ dnnl::inner_product_forward::primitive_desc
GetFCFwdImpl(const DNNLFCFullParam&
full_param.eltwise_param.alpha,
full_param.eltwise_param.beta);
}
+ if (full_param.dnnl_param.with_sum) {
+ ops.append_sum(full_param.sum_scale);
+ }
attr.set_post_ops(ops);
if (full_param.dnnl_param.quantized && full_param.output_scales.size()) {
diff --git a/src/operator/subgraph/build_subgraph.cc
b/src/operator/subgraph/build_subgraph.cc
index ef1218b..4acaa22 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -749,6 +749,17 @@ void CreateSubgraphNode(nnvm::Graph* g,
for (BiDirectedNode* dest_node : subgraph_nodes) {
sn->outputs.erase(dest_node->node);
}
+ }
+ }
+
+ // Set outputs according to current inputs
+ for (size_t i = 0; i < n->inputs.size(); ++i) {
+ auto& e = n->inputs[i];
+ // update input entries' source simple nodes' outputs map
+ nnvm::Node* node = e.node.get();
+ if (indexed_graph.exist(node)) {
+ const auto nid = indexed_graph.node_id(node);
+ BiDirectedNode* sn = simple_nodes[nid].get();
sn->outputs[n.get()].push_back(i);
}
}
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc
b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 49887fe..8ead3e7 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -25,7 +25,10 @@
#if MXNET_USE_ONEDNN == 1
+#include <memory>
#include <string>
+#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -62,8 +65,8 @@ class SgDNNLFCOp {
private:
bool initialized_{false};
- bool channel_wise_runtime_{false};
bool reorder_data_{false};
+ bool inplace_{false};
nnvm::Symbol subgraph_sym_;
DNNLFCFullParam full_param_;
dnnl_args_map_t args_;
@@ -72,83 +75,123 @@ class SgDNNLFCOp {
std::shared_ptr<dnnl::memory> cached_out_mem_;
NDArray cached_weight_;
NDArray cached_bias_;
- float cached_min_data_;
- float cached_max_data_;
- float cached_min_weight_;
- float cached_max_weight_;
- float cached_min_bias_;
- float cached_max_bias_;
+ float cached_data_min_;
+ float cached_data_max_;
+ float cached_weight_min_;
+ float cached_weight_max_;
+ float cached_sum_min_;
+ float cached_sum_max_;
+ float cached_bias_min_;
+ float cached_bias_max_;
size_t weight_ver_;
size_t bias_ver_;
- float cached_min_output_;
- float cached_max_output_;
+ float cached_output_min_;
+ float cached_output_max_;
float data_scale_{0.0f};
std::vector<float> weight_scales_;
- size_t total_num_inputs_;
- size_t total_num_outputs_;
};
void SgDNNLFCOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
- auto& dnnl_param = full_param_.dnnl_param;
- auto& default_param = full_param_.default_param;
- bool has_bias = !default_param.no_bias;
- size_t base_num_inputs = has_bias ? 3 : 2;
- size_t base_num_outputs = 1;
-
- float min_data = 0.0f;
- float max_data = 0.0f;
- float min_weight = 0.0f;
- float max_weight = 0.0f;
- float min_bias = 0.0f;
- float max_bias = 0.0f;
-
- if (!initialized_) {
- if (dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize) {
- channel_wise_runtime_ = true;
+ auto& dnnl_param = full_param_.dnnl_param;
+ auto& default_param = full_param_.default_param;
+ const bool has_bias = !default_param.no_bias;
+ const bool quantized = dnnl_param.quantized;
+ const bool out_quantized = dnnl_param.quantized &&
!dnnl_param.enable_float_output;
+ const bool channel_wise = quantized &&
dnnl_param.channel_wise_quantize.has_value() &&
+ dnnl_param.channel_wise_quantize.value();
+
+ const FCInputIndex idx(full_param_);
+
+ CHECK_EQ(in_data.size(), idx.GetTotal());
+
+ int index = 0;
+ const int out_index = index++;
+ const int out_min_index = out_quantized ? index++ : 0;
+ const int out_max_index = out_quantized ? index++ : 0;
+ CHECK_EQ(out_data.size(), index); // index is equal to total number of
outputs
+
+ float data_min = 0.0f;
+ float data_max = 0.0f;
+ float weight_min = 0.0f;
+ float weight_max = 0.0f;
+ float bias_min = 0.0f;
+ float bias_max = 0.0f;
+
+ const float sum_min = idx.sum_min ?
in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
+ const float sum_max = idx.sum_max ?
in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
+ NDArray data = in_data[idx.data];
+ const NDArray& weight = in_data[idx.weight];
+ NDArray output;
+
+ if (dnnl_param.with_sum) {
+ if (!initialized_) {
+ // TODO(zhennan): Currently, the dnnl fallback mechanism will break the inplace option,
+ // which makes the check (req[out_index] == kWriteInplace) useless.
+ auto in_dnnl_mem = static_cast<const
dnnl::memory*>(in_data[idx.sum].GetDNNLData());
+ auto out_dnnl_mem = static_cast<const
dnnl::memory*>(out_data[out_index].GetDNNLData());
+ if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle()) {
+ inplace_ = true;
+ }
}
-
- total_num_inputs_ = base_num_inputs;
- total_num_outputs_ = base_num_outputs;
- if (dnnl_param.quantized) {
- total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) :
(base_num_inputs * 3);
- total_num_outputs_ =
- dnnl_param.enable_float_output ? base_num_outputs :
(base_num_outputs * 3);
+ if (inplace_) {
+ output = in_data[idx.sum];
+ } else {
+ // Not in place: copy in_data[idx.sum] into outputs[out_index].
+ auto in_dnnl_mem = static_cast<const
dnnl::memory*>(in_data[idx.sum].GetDNNLData());
+ auto out_dnnl_mem = static_cast<const
dnnl::memory*>(out_data[out_index].GetDNNLData());
+ if (out_data[out_index].dtype() == mshadow::kInt32) {
+ auto mem_desc = in_dnnl_mem->get_desc();
+ auto this_dtype = get_dnnl_type(mshadow::kInt32);
+ mem_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
+ dnnl_mem_ptr tmp_mem(new dnnl::memory(
+ mem_desc, CpuEngine::Get()->get_engine(),
out_dnnl_mem->get_data_handle()));
+ DNNLStream::Get()->RegisterMem(tmp_mem);
+ DNNLStream::Get()->RegisterPrimArgs(
+ dnnl::reorder(*in_dnnl_mem, *tmp_mem),
+ {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
+ output = NDArray(tmp_mem);
+ } else {
+ dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
+ CpuEngine::Get()->get_engine(),
+
out_dnnl_mem->get_data_handle()));
+ DNNLStream::Get()->RegisterMem(tmp_mem);
+ DNNLMemoryCopy(*in_dnnl_mem, tmp_mem.get());
+ output = NDArray(tmp_mem);
+ }
}
+ } else {
+ output = out_data[out_index];
}
- CHECK_EQ(in_data.size(), total_num_inputs_);
- CHECK_EQ(out_data.size(), total_num_outputs_);
-
- NDArray data = in_data[fullc::kData];
- const NDArray& weight = in_data[fullc::kWeight];
- const NDArray& output = out_data[fullc::kOut];
if (dnnl_param.quantized) {
- if (!channel_wise_runtime_) {
- min_weight = in_data[base_num_inputs +
quantized_fullc::kWeightMin].data().dptr<float>()[0];
- max_weight = in_data[base_num_inputs +
quantized_fullc::kWeightMax].data().dptr<float>()[0];
+ if (!channel_wise) {
+ weight_min = in_data[idx.weight_min].data().dptr<float>()[0];
+ weight_max = in_data[idx.weight_max].data().dptr<float>()[0];
if (has_bias) {
- min_bias = in_data[base_num_inputs +
quantized_fullc::kBiasMin].data().dptr<float>()[0];
- max_bias = in_data[base_num_inputs +
quantized_fullc::kBiasMax].data().dptr<float>()[0];
+ bias_min = in_data[idx.bias_min].data().dptr<float>()[0];
+ bias_max = in_data[idx.bias_max].data().dptr<float>()[0];
}
}
- min_data = in_data[base_num_inputs +
quantized_fullc::kDataMin].data().dptr<float>()[0];
- max_data = in_data[base_num_inputs +
quantized_fullc::kDataMax].data().dptr<float>()[0];
+ data_min = in_data[idx.data_min].data().dptr<float>()[0];
+ data_max = in_data[idx.data_max].data().dptr<float>()[0];
}
if (initialized_ && dnnl_param.quantized &&
dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
- if (channel_wise_runtime_) {
- if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+ if (channel_wise) {
+ if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
+ cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
weight_ver_ != weight.version() ||
- (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
+ (has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
initialized_ = false;
}
} else {
- if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
- cached_min_weight_ != min_weight || cached_max_weight_ != max_weight
||
- (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ !=
max_bias))) {
+ if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
+ cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
+ cached_weight_min_ != weight_min || cached_weight_max_ != weight_max
||
+ (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ !=
bias_max))) {
initialized_ = false;
}
}
@@ -157,17 +200,19 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
if (!initialized_) {
const auto nthreads =
engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const auto engine = CpuEngine::Get()->get_engine();
- cached_min_data_ = min_data;
- cached_max_data_ = max_data;
- cached_min_weight_ = min_weight;
- cached_max_weight_ = max_weight;
+ cached_data_min_ = data_min;
+ cached_data_max_ = data_max;
+ cached_weight_min_ = weight_min;
+ cached_weight_max_ = weight_max;
weight_ver_ = weight.version();
cached_weight_ = weight;
+ cached_sum_min_ = sum_min;
+ cached_sum_max_ = sum_max;
if (has_bias) {
- cached_min_bias_ = min_bias;
- cached_max_bias_ = max_bias;
- bias_ver_ = in_data[fullc::kBias].version();
- cached_bias_ = in_data[fullc::kBias];
+ cached_bias_min_ = bias_min;
+ cached_bias_max_ = bias_max;
+ bias_ver_ = in_data[idx.bias].version();
+ cached_bias_ = in_data[idx.bias];
} else {
cached_bias_ = NDArray();
}
@@ -210,13 +255,13 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
bool support_channelwise_scale = false;
if (dnnl_param.quantized) {
CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
- data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_,
cached_max_data_);
+ data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_,
cached_data_max_);
bool fuse_requantize = false;
// Channelwise scaling is only supported when fusion is enabled
(requantize or dequantize).
if (dnnl_param.min_calib_range.has_value() &&
dnnl_param.max_calib_range.has_value()) {
- cached_min_output_ = dnnl_param.min_calib_range.value();
- cached_max_output_ = dnnl_param.max_calib_range.value();
+ cached_output_min_ = dnnl_param.min_calib_range.value();
+ cached_output_max_ = dnnl_param.max_calib_range.value();
support_channelwise_scale = true;
fuse_requantize = true;
}
@@ -227,7 +272,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
// True True True
// True False Error
// False True/False False
- if (channel_wise_runtime_ && !support_channelwise_scale) {
+ if (channel_wise && !support_channelwise_scale) {
LOG(FATAL)
<< "Currently, channel-wise quantization requires fuse requantize
or dequantize."
<< " Please make sure the `min_calib_range` and `max_calib_range`
are set when only"
@@ -236,7 +281,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
<< " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and
"
<< " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true
(default is false)";
}
- support_channelwise_scale = support_channelwise_scale &&
channel_wise_runtime_;
+ support_channelwise_scale = support_channelwise_scale && channel_wise;
if (support_channelwise_scale) {
MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
@@ -248,51 +293,56 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
} else {
weight_scales_.resize(1);
weight_scales_[0] =
- GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_,
cached_max_weight_);
+ GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_,
cached_weight_max_);
if (has_bias) {
- float bias_scale = GetQuantizeScale(mshadow::kInt8,
cached_min_bias_, cached_max_bias_);
- float bias_int32_rescale = data_scale_ * weight_scales_[0] /
bias_scale;
- // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set the
maximum value
- // of bias to INT_MAX / 2.
- float bias_max_rescale =
- MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_,
cached_max_bias_) / bias_scale;
- if (bias_int32_rescale > bias_max_rescale) {
- // avoid overflow on bias
- bias_int32_rescale = bias_max_rescale;
- float weight_rescale =
- bias_int32_rescale * bias_scale / data_scale_ /
weight_scales_[0];
- int8_t* weight_ptr = weight.data().dptr<int8_t>();
- size_t weight_size = weight.shape().Size();
+ if (cached_bias_.dtype() == mshadow::kInt8) {
+ float bias_scale = GetQuantizeScale(mshadow::kInt8,
cached_bias_min_, cached_bias_max_);
+
+ float bias_int32_rescale = data_scale_ * weight_scales_[0] /
bias_scale;
+ // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
+ // the maximum value of bias to INT_MAX / 2.
+ float bias_max_rescale =
+ MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_,
cached_bias_max_) / bias_scale;
+ if (bias_int32_rescale > bias_max_rescale) {
+ // avoid overflow on bias
+ bias_int32_rescale = bias_max_rescale;
+ float weight_rescale =
+ bias_int32_rescale * bias_scale / data_scale_ /
weight_scales_[0];
+ int8_t* weight_ptr = weight.data().dptr<int8_t>();
+ size_t weight_size = weight.shape().Size();
#pragma omp parallel for num_threads(nthreads)
- for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
- weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+ for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
+ weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+ }
+ weight_scales_[0] *= weight_rescale;
}
- weight_scales_[0] *= weight_rescale;
- }
- NDArray bias = in_data[fullc::kBias];
- cached_bias_ =
- NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true,
mshadow::kInt32);
- int8_t* bias_ptr = bias.data().dptr<int8_t>();
- int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
- size_t bias_size = bias.shape().Size();
+ NDArray bias = in_data[fullc::kBias];
+ cached_bias_ =
+ NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true,
mshadow::kInt32);
+ int8_t* bias_ptr = bias.data().dptr<int8_t>();
+ int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
+ size_t bias_size = bias.shape().Size();
+
#pragma omp parallel for num_threads(nthreads)
- for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
- quantized_bias_ptr[i] = std::round(bias_ptr[i] *
bias_int32_rescale);
+ for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
+ quantized_bias_ptr[i] = std::round(bias_ptr[i] *
bias_int32_rescale);
+ }
}
}
}
size_t num_channel = cached_weight_.shape()[0];
+ float out_scale = 1.0f;
if (fuse_requantize || dnnl_param.enable_float_output) {
float tmp_scale_ = 1.0f;
if (fuse_requantize) {
if (dnnl_param.with_eltwise) {
tmp_scale_ = 1.0 / data_scale_;
full_param_.eltwise_param.scale =
- GetQuantizeScale(output.dtype(), cached_min_output_,
cached_max_output_);
+ GetQuantizeScale(output.dtype(), cached_output_min_,
cached_output_max_);
} else {
- tmp_scale_ = GetQuantizeScale(output.dtype(), cached_min_output_,
cached_max_output_) /
- data_scale_;
+ out_scale = GetQuantizeScale(output.dtype(), cached_output_min_,
cached_output_max_);
+ tmp_scale_ = out_scale / data_scale_;
}
} else {
tmp_scale_ = 1.0 / data_scale_;
@@ -314,26 +364,33 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct,
cpu>::Launch(
s,
1,
- &cached_min_output_,
- &cached_max_output_,
- &min_data,
- &max_data,
- &min_weight,
- &max_weight);
+ &cached_output_min_,
+ &cached_output_max_,
+ &data_min,
+ &data_max,
+ &weight_min,
+ &weight_max);
} else {
mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct,
cpu>::Launch(
s,
1,
- &cached_min_output_,
- &cached_max_output_,
- &min_data,
- &max_data,
- &min_weight,
- &max_weight);
+ &cached_output_min_,
+ &cached_output_max_,
+ &data_min,
+ &data_max,
+ &weight_min,
+ &weight_max);
}
full_param_.output_scales.resize(0);
+ out_scale = data_scale_ * weight_scales_[0];
}
- }
+
+ if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+ float sum_in_scale =
+ GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_,
cached_sum_max_);
+ full_param_.sum_scale = out_scale / sum_in_scale;
+ }
+ } // if (dnnl_param.quantized)
fwd_.reset(new DNNLFullyConnectedForward(full_param_,
ctx.is_train,
@@ -357,10 +414,11 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
weight_scales_,
false);
} else {
- const auto def_weight_mem = weight.GetDNNLData();
+ const auto def_weight_mem = static_cast<const
dnnl::memory*>(weight.GetDNNLData());
if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
- cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
- auto cached_weight_mem = cached_weight_.GetDNNLData();
+ auto weight_desc = fwd_->fwd_pd.weights_desc();
+ cached_weight_ = NDArray(weight_desc);
+ auto cached_weight_mem = static_cast<const
dnnl::memory*>(cached_weight_.GetDNNLData());
std::unordered_map<int, dnnl::memory> args(
{{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO,
*cached_weight_mem}});
DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem,
*cached_weight_mem),
@@ -368,17 +426,32 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
}
}
- const auto data_mem = data.GetDNNLData();
+ const auto data_mem = static_cast<const dnnl::memory*>(data.GetDNNLData());
cached_data_mem_ = std::make_shared<dnnl::memory>(data_mem->get_desc(),
engine);
args_[DNNL_ARG_SRC] = *cached_data_mem_;
- args_[DNNL_ARG_WEIGHTS] = *cached_weight_.GetDNNLData();
+ args_[DNNL_ARG_WEIGHTS] = *static_cast<const
dnnl::memory*>(cached_weight_.GetDNNLData());
if (has_bias)
- args_[DNNL_ARG_BIAS] = *cached_bias_.GetDNNLData();
+ args_[DNNL_ARG_BIAS] = *static_cast<const
dnnl::memory*>(cached_bias_.GetDNNLData());
args_[DNNL_ARG_DST] = *cached_out_mem_;
initialized_ = true;
}
+ if (dnnl_param.with_sum) {
+ const auto& output_mem = output.GetDNNLData();
+ const auto& out_mem_desc = output_mem->get_desc();
+ auto dst_mem_desc = fwd_->fwd_pd.dst_desc();
+ if (out_mem_desc != dst_mem_desc) {
+ auto tmp_out_mem = output.GetDNNLDataReorder(dst_mem_desc);
+ dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
+ dnnl_mem_ptr new_out_mem(new dnnl::memory(
+ dst_mem_desc, CpuEngine::Get()->get_engine(),
output_mem->get_data_handle()));
+ DNNLStream::Get()->RegisterMem(new_out_mem);
+ DNNLMemoryCopy(*tmp_out_mem, new_out_mem.get());
+ output = NDArray(new_out_mem);
+ }
+ }
+
if (reorder_data_) {
data = data.Reorder2Default();
}
@@ -392,10 +465,11 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
DNNLStream::Get()->Submit();
if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
- float* min_output_ptr =
out_data[quantized_fullc::kOutMin].data().dptr<float>();
- float* max_output_ptr =
out_data[quantized_fullc::kOutMax].data().dptr<float>();
- *min_output_ptr = cached_min_output_;
- *max_output_ptr = cached_max_output_;
+ float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
+ float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
+
+ *output_min_ptr = cached_output_min_;
+ *output_max_ptr = cached_output_max_;
}
}
@@ -450,23 +524,25 @@ static void SgDNNLFCParamParser(nnvm::NodeAttrs* attrs) {
static std::vector<std::string> SgDNNLFCListInputNames(const NodeAttrs& attrs)
{
auto const& full_param =
nnvm::get<DNNLFCFullParam>(attrs.parsed);
+ auto const& dnnl_param = full_param.dnnl_param;
std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
- if (full_param.dnnl_param.quantized) {
- bool channel_wise = false;
- if (full_param.dnnl_param.channel_wise_quantize.has_value() &&
- full_param.dnnl_param.channel_wise_quantize) {
- channel_wise = true;
- }
- input_names.emplace_back("min_data");
- input_names.emplace_back("max_data");
+ if (dnnl_param.quantized) {
+ const bool channel_wise =
+ dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize;
+ input_names.emplace_back("data_min");
+ input_names.emplace_back("data_max");
if (!channel_wise) {
- input_names.emplace_back("min_weight");
- input_names.emplace_back("max_weight");
+ input_names.emplace_back("weight_min");
+ input_names.emplace_back("weight_max");
if (!full_param.default_param.no_bias) {
- input_names.emplace_back("min_bias");
- input_names.emplace_back("max_bias");
+ input_names.emplace_back("bias_min");
+ input_names.emplace_back("bias_max");
}
}
+ if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+ input_names.emplace_back("sum_min");
+ input_names.emplace_back("sum_max");
+ }
}
return input_names;
}
@@ -477,19 +553,19 @@ static std::vector<std::string>
SgDNNLFCListOutputNames(const NodeAttrs& attrs)
if (full_param.dnnl_param.enable_float_output)
return std::vector<std::string>{"output"};
else
- return std::vector<std::string>{"output", "min_output", "max_output"};
+ return std::vector<std::string>{"output", "output_min", "output_max"};
} else {
return std::vector<std::string>{"output"};
}
}
template <typename T>
-static inline void FillBaseInputOutputInfo(const FullyConnectedParam& param,
+static inline void FillBaseInputOutputInfo(const DNNLFCFullParam& param,
std::vector<T>* base_in_attrs,
std::vector<T>* base_out_attrs,
std::vector<T>* in_attrs,
std::vector<T>* out_attrs) {
- auto base_num_inputs = param.no_bias ? 2 : 3;
+ auto base_num_inputs = FCInputIndex(param).GetBase();
base_out_attrs->push_back(out_attrs->at(0));
for (int i = 0; i < base_num_inputs; ++i) {
@@ -504,8 +580,7 @@ static bool SgDNNLFCInferShape(const nnvm::NodeAttrs& attrs,
if (full_param.dnnl_param.quantized) {
mxnet::ShapeVector base_in_shapes;
mxnet::ShapeVector base_out_shapes;
- FillBaseInputOutputInfo(
- full_param.default_param, &base_in_shapes, &base_out_shapes,
in_shapes, out_shapes);
+ FillBaseInputOutputInfo(full_param, &base_in_shapes, &base_out_shapes,
in_shapes, out_shapes);
bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes,
&base_out_shapes);
for (size_t i = 0; i < in_shapes->size(); ++i) {
@@ -531,26 +606,43 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs&
attrs,
std::vector<int>* out_types) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
if (full_param.dnnl_param.quantized) {
- bool channel_wise = false;
- if (full_param.dnnl_param.channel_wise_quantize.has_value() &&
- full_param.dnnl_param.channel_wise_quantize) {
- channel_wise = true;
- }
- size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;
- CHECK(in_types->at(0) == mshadow::kInt8 || in_types->at(0) ==
mshadow::kUint8)
- << "QuantizedFullyConnected only supports int8/uint8 input, while " <<
in_types->at(0)
- << " is given.";
- for (size_t i = 1; i < in_types->size(); ++i) {
- if (channel_wise) {
- TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
- } else {
- if (i < base_num_inputs) {
- TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
+ const bool channel_wise =
full_param.dnnl_param.channel_wise_quantize.has_value() &&
+ full_param.dnnl_param.channel_wise_quantize;
+ const FCInputIndex idx(full_param);
+
+ CHECK(in_types->at(idx.data) == mshadow::kInt8 || in_types->at(idx.data)
== mshadow::kUint8)
+ << "QuantizedFullyConnected data input only supports int8/uint8,
while "
+ << in_types->at(idx.data) << " is given.";
+ if (channel_wise) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kFloat32);
+ if (idx.IsBiasExist()) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kFloat32);
+ }
+ } else {
+ TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kInt8);
+ if (idx.IsBiasExist()) {
+ if (in_types->at(idx.bias) == -1) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kInt32);
} else {
- TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ CHECK(in_types->at(idx.bias) == mshadow::kInt8 ||
+ in_types->at(idx.bias) == mshadow::kInt32)
+ << "QuantizedFullyConnected bias input only supports int8/int32,
while "
+ << in_types->at(idx.bias) << " is given.";
}
}
}
+ if (idx.IsSumExist()) {
+ if (full_param.dnnl_param.enable_float_output) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+ } else {
+ CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum)
== mshadow::kUint8)
+ << "QuantizedFullyConnected sum input only supports int8/uint8,
while "
+ << in_types->at(idx.sum) << " is given.";
+ }
+ }
+ for (size_t i = idx.data_min; i < in_types->size(); ++i) {
+ TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ }
if (full_param.dnnl_param.enable_float_output) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
@@ -583,8 +675,7 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs&
attrs,
if (full_param.dnnl_param.quantized) {
std::vector<int> base_in_attrs;
std::vector<int> base_out_attrs;
- FillBaseInputOutputInfo(
- full_param.default_param, &base_in_attrs, &base_out_attrs, in_attrs,
out_attrs);
+ FillBaseInputOutputInfo(full_param, &base_in_attrs, &base_out_attrs,
in_attrs, out_attrs);
bool ret = DefaultSubgraphOpStorageType(
attrs, dev_mask, dispatch_mode, &base_in_attrs, &base_out_attrs);
@@ -606,6 +697,15 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs&
attrs,
}
}
+std::vector<std::pair<int, int>> SgDNNLFCInplaceOption(const NodeAttrs& attrs)
{
+ auto const& param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
+ if (param.dnnl_param.with_sum) {
+ return std::vector<std::pair<int, int>>{{FCInputIndex(param).sum, 0}};
+ } else {
+ return std::vector<std::pair<int, int>>();
+ }
+}
+
static OpStatePtr CreateSgDNNLFCState(const nnvm::NodeAttrs& attrs,
Context ctx,
const mxnet::ShapeVector& in_shapes,
@@ -641,13 +741,16 @@ static bool SgDNNLAvoidFCQuantizeInput(const NodeAttrs&
attrs,
const std::string quantize_granularity)
{
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
std::unordered_set<size_t> avoid_indexes;
+ FCInputIndex idx(full_param);
if (quantize_granularity == "channel-wise") {
avoid_indexes.insert(fullc::kWeight); // weight
if (!full_param.default_param.no_bias) {
avoid_indexes.insert(fullc::kBias); // bias
}
}
-
+ if (idx.IsSumInputFloat()) {
+ avoid_indexes.insert(idx.sum);
+ }
return avoid_indexes.count(index_to_check);
}
@@ -656,17 +759,7 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected)
.describe(R"code(_sg_onednn_fully_connected)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
- auto num_inputs = full_param.default_param.no_bias ? 2 : 3;
- if (full_param.dnnl_param.quantized) {
- if (full_param.dnnl_param.channel_wise_quantize.has_value() &&
- full_param.dnnl_param.channel_wise_quantize) {
- return num_inputs + 2; // min_data, max_data
- } else {
- return num_inputs * 3;
- }
- } else {
- return num_inputs;
- }
+ return FCInputIndex(full_param).GetTotal();
})
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
@@ -691,6 +784,7 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected)
})
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
DefaultSubgraphOpMutableInputs)
.set_attr<std::string>("key_var_num_args", "num_args")
+ .set_attr<nnvm::FInplaceOption>("FInplaceOption", SgDNNLFCInplaceOption)
.set_attr<FQuantizable>("FQuantizable",
[](const NodeAttrs& attrs) { return
QuantizeType::kMust; })
.set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLFCQuantizedOp)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h
b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index 64fd507..b22c5ef 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -194,6 +194,9 @@ class SgDNNLFCProperty : public SubgraphProperty {
auto& sub_name = node->op()->name;
if (sub_name == "FullyConnected") {
node_name << "fully_connected_";
+ if (HasAttr("quantize") && GetAttr<bool>("quantize")) {
+ n->attrs.dict["first_quantization_pass"] = "True";
+ }
} else if (SupportDNNLFCEltwiseFusion(sub_name)) {
node_name << "eltwise_";
n->attrs.dict["with_eltwise"] = "True";
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
new file mode 100644
index 0000000..4af89c9
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ \file
+ \brief For fusing FullyConnected operator with element-wise add.
+
+ Element-wise add operator is replaced by DNNL FC "sum" post operator.
+ It adds FC results to existing values in output. For quantized integer
version
+ this output is scaled to the proper range.
+*/
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "../../tensor/matrix_op-inl.h"
+#include "../common.h"
+#include "dnnl_fc-inl.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool EndsWith(std::string const& value, std::string const& ending) {
+ if (ending.size() > value.size()) {
+ return false;
+ } else {
+ return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+ }
+}
+
+class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
+ private:
+ /*! \brief pattern match status */
+ enum SelectStatus {
+ kFail = 0,
+ kStart,
+ kSuccess,
+ };
+
+ bool quantized_;
+ SelectStatus status_ = kFail;
+ std::vector<const BiDirectedNode*> matched_list_;
+
+ public:
+ explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
+
+ bool Select(const BiDirectedNode& seed_node,
+ const std::shared_ptr<NodeAttr>& node_attr) override {
+ const auto n = seed_node.node;
+ if (n->op() == Op::Get("_sg_onednn_fully_connected") &&
SupportDNNLAttr(node_attr) &&
+ (seed_node.outputs.size() == 1)) {
+ auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+ if ((!quantized_ && !fc_param.dnnl_param.first_quantization_pass) ||
+ (fc_param.dnnl_param.quantized &&
!fc_param.dnnl_param.with_eltwise)) {
+ // Start subgraph when fusing for floats (quantized_ is false for DNNL backend) or
+ // when FC is already quantized (second pass for DNNL_QUANTIZE) but not already fused
+ // with elemwise operator.
+ status_ = kStart;
+ matched_list_.clear();
+ matched_list_.push_back(&seed_node);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode&
input_node) override {
+ return false;
+ }
+
+ bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode&
output_node) override {
+ const auto cur_n = cur_node.node;
+ const auto output_n = output_node.node;
+ if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
+ return false;
+ }
+ // If n isn't the last matched node, then we encountered an internal
+ // branch; we should pop out the node behind n and stop fusion.
+ if (matched_list_.back() != &cur_node) {
+ if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) !=
matched_list_.end()) {
+ while (matched_list_.back() != &cur_node) {
+ matched_list_.pop_back();
+ }
+ }
+ status_ = kSuccess;
+ return false;
+ }
+
+ switch (status_) {
+ case kStart:
+ // Find _contrib_quantized_elemwise_add or elemwise_add
+ if (EndsWith(output_n->op()->name, "elemwise_add")) {
+ if (quantized_) {
+ auto const& fc_param =
nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
+ if (!fc_param.dnnl_param.enable_float_output) {
+ // For a quantized graph, when FC floating point output is not enabled,
+ // elementwise add must also be quantized (min and max values have to be
+ // already stored in elementwise add).
+ CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
+ }
+ }
+ matched_list_.push_back(&output_node);
+ status_ = kSuccess;
+ return true;
+ }
+ default:
+ status_ = kFail;
+ return false;
+ }
+ }
+
+ std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>&
candidates) override {
+ if (status_ == kSuccess) {
+ return candidates;
+ } else {
+ return std::vector<BiDirectedNode*>(0);
+ }
+ }
+
+ void Reset() override {
+ CHECK_GE(matched_list_.size(), 1);
+ auto new_selector = SgDNNLFCSumFuseSelector(quantized_);
+ new_selector.Select(*matched_list_[0], nullptr);
+ *this = new_selector;
+ }
+};
+
+class SgDNNLFCSumFuseProperty : public SubgraphProperty {
+ public:
+ SgDNNLFCSumFuseProperty() {}
+
+ static SubgraphPropertyPtr Create() {
+ static const std::string& name = "DNNL fuse FullyConnected with sum";
+ auto property =
std::make_shared<SgDNNLFCSumFuseProperty>();
+ property->SetAttr<std::string>("property_name", name);
+ property->SetAttr<bool>("inference_only", true);
+ if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) {
+ property->SetAttr<bool>("disable", true);
+ }
+ return property;
+ }
+
+ nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+ const int subgraph_id = 0) const override
{
+ nnvm::ObjectPtr fc_node = nullptr;
+ nnvm::ObjectPtr ew_add_node = nullptr;
+
+ DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
+ if (node->is_variable()) {
+ return;
+ }
+ auto& sub_name = node->op()->name;
+ if (sub_name == "_sg_onednn_fully_connected") {
+ fc_node = node;
+ } else if (EndsWith(sub_name, "elemwise_add")) {
+ ew_add_node = node;
+ }
+ });
+
+ CHECK_NOTNULL(fc_node);
+ if (ew_add_node != nullptr) {
+ CHECK_NOTNULL(fc_node->attrs.subgraphs[0]);
+ auto subgraph_output_node = fc_node->attrs.subgraphs[0]->outputs[0].node;
+ nnvm::Symbol new_sym;
+ // Create a new elemwise_add node to not alter the original one.
+ // It is needed in subgraph to properly calculate InferShape.
+ nnvm::ObjectPtr n = nnvm::Node::Create();
+ n->attrs.op = Op::Get("elemwise_add");
+ n->attrs.name = ew_add_node->attrs.name;
+
+ if (ew_add_node->inputs[0].node == fc_node) {
+ n->inputs.emplace_back(subgraph_output_node);
+ n->inputs.emplace_back(ew_add_node->inputs[1]);
+ } else {
+ n->inputs.emplace_back(ew_add_node->inputs[0]);
+ n->inputs.emplace_back(subgraph_output_node);
+ }
+ new_sym.outputs.emplace_back(n);
+ fc_node->attrs.subgraphs.clear();
+
fc_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+ fc_node->attrs.dict["with_sum"] = "True";
+ fc_node->attrs.dict.erase("first_quantization_pass"); // Removed as not
needed any longer
+ fc_node->op()->attr_parser(&(fc_node->attrs));
+ }
+ return fc_node;
+ }
+
+ SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+ bool quantized = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
+ auto selector = std::make_shared<SgDNNLFCSumFuseSelector>(quantized);
+ return selector;
+ }
+
+ void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+ std::vector<nnvm::NodeEntry*>* output_entries)
const override {
+ // Connect all extern output entries to output[0]
+ for (size_t i = 0; i < output_entries->size(); ++i) {
+ auto entry_ptr = output_entries->at(i);
+ *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+ }
+ }
+
+ void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
+ std::vector<nnvm::NodeEntry*>* input_entries,
+ std::vector<nnvm::NodeEntry>* orig_input_entries)
const override {
+ auto sym = n->attrs.subgraphs[0];
+ auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+ std::unordered_set<const nnvm::Node*> node_sets;
+ DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
+ if (node->is_variable()) {
+ return;
+ }
+ node_sets.insert(node.get());
+ if (EndsWith(node->op()->name, "elemwise_add")) {
+ const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
+ // Make sure fc output is the left operand of the add operator, if not:
+ // - swap inputs of add operator
+ // - switch add operands sequence to ensure that
+ // the tensor (sum_tensor) to which FC output is added is the last
input.
+ if (node_sets.count(node->inputs[1].node.get())) {
+ // Example of input_entries reordering for channel-wise quantized
graph:
+ // sum_tensor.data --> fc.data
+ // fc.data --> fc.weight0
+ // fc.weight0 --> fc.bias0
+ // fc.bias0 --> sum_tensor.data
+ // fc_out.min --> fc_out.min
+ // fc_out.max --> fc_out.max
+ // sum_tensor.min --> sum_tensor.min
+ // sum_tensor.max --> sum_tensor.max
+ std::swap(node->inputs[0], node->inputs[1]);
+ std::rotate(input_entries->begin(),
+ input_entries->begin() + 1,
+ input_entries->begin() + base_inputs);
+ std::rotate(orig_input_entries->begin(),
+ orig_input_entries->begin() + 1,
+ orig_input_entries->begin() + base_inputs);
+ } else {
+ // Example of input_entries reordering for channel-wise quantized
graph:
+ // fc.data --> fc.data
+ // fc.weight0 --> fc.weight0
+ // fc.bias0 --> fc.bias0
+ // fc_out.min --> sum_tensor.data
+ // fc_out.max --> fc_out.min
+ // sum_tensor.data --> fc_out.max
+ // sum_tensor.min --> sum_tensor.min
+ // sum_tensor.max --> sum_tensor.max
+ const int not_rotated_end =
+ (fc_param.dnnl_param.quantized &&
!fc_param.dnnl_param.enable_float_output) ? 2 : 0;
+
+ std::rotate(input_entries->begin() + base_inputs - 1,
+ input_entries->end() - 1 - not_rotated_end,
+ input_entries->end() - not_rotated_end);
+ std::rotate(orig_input_entries->begin() + base_inputs - 1,
+ orig_input_entries->end() - 1 - not_rotated_end,
+ orig_input_entries->end() - not_rotated_end);
+ }
+ }
+ });
+ n->inputs = *orig_input_entries;
+ }
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
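
Illustration (not part of this patch): a short Python sketch of the quantized flow the
file header above describes, modeled on check_quantize() in the test changes further
below; the FCAdd block, shapes, and calibration settings are assumptions. After the
ONEDNN_QUANTIZE passes run, the elemwise_add should be absorbed as a scaled "sum"
post-op of the quantized FC.

import mxnet as mx
from mxnet.gluon import nn
from mxnet.contrib import quantization

@mx.util.use_np
def quantized_fc_add_sketch():
    class FCAdd(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(FCAdd, self).__init__(**kwargs)
            self.fc = nn.Dense(units=16, flatten=True)

        def forward(self, data, residual):
            return self.fc(data) + residual

    net = FCAdd()
    net.initialize(init=mx.init.Normal(0.5))
    net.hybridize()
    data = mx.np.random.uniform(size=(4, 32))
    residual = mx.np.random.uniform(size=(4, 16))
    net(data, residual)

    calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(data, residual),
                                          batch_size=1)
    qnet = quantization.quantize_net(net,
                                     device=mx.cpu(),
                                     quantized_dtype='auto',
                                     calib_mode='naive',
                                     calib_data=calib_data,
                                     num_calib_batches=1,
                                     quantize_mode='full',
                                     quantize_granularity='tensor-wise')
    qsym, _ = qnet.export(None)
    # Expect a _sg_onednn_fully_connected node with quantized=true and with_sum=true;
    # its output range is rescaled through sum_scale as described above.
    return qsym.attr_dict()
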
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index d9cb6c0..1aa52ca 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -45,6 +45,9 @@ const std::set<std::string> support_req_fusion_op =
{"_contrib_quantized_elemwis
"_sg_onednn_selfatt_qk",
"_sg_onednn_selfatt_valatt",
"_sg_onednn_batch_dot"};
+
+const std::set<const Op*> no_enable_float_output =
{Op::Get("_contrib_quantized_elemwise_add"),
+
Op::Get("_sg_onednn_conv")};
} // namespace
class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -110,7 +113,8 @@ class SgDNNLPostQuantizeSelector : public
SubgraphSelectorV2 {
if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
matched_list.emplace_back(&new_node);
status = SelectStatus::kRequantize;
- if (raw_node->op() == Op::Get("_sg_onednn_conv")) {
+ if ((raw_node->op() == Op::Get("_sg_onednn_conv")) ||
+ (raw_node->op() ==
Op::Get("_contrib_quantized_elemwise_add"))) {
status = SelectStatus::kSuccess;
}
return true;
@@ -210,7 +214,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
// When only fused quantized operator and requantize, set
min/max_cablib_range,
// When fused quantized operator + requantize + dequantize, set dequantize
flag to true.
- if (dequantize_node != nullptr) {
+ if ((dequantize_node != nullptr) &&
(no_enable_float_output.count(fuse_node->op()) == 0)) {
fuse_node->attrs.dict["enable_float_output"] = "True";
} else {
fuse_node->attrs.dict["min_calib_range"] =
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 8f8fc44..b3b23e1 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -28,6 +28,7 @@
#include "dnnl_post_quantize_property.h"
#include "dnnl_transformer_qk_property.h"
#include "dnnl_transformer_valatt_property.h"
+#include "dnnl_fc_sum_fuse.h"
namespace mxnet {
namespace op {
@@ -43,6 +44,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN,
SgDNNLBNReLUProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCSumFuseProperty);
MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context",
Context::CPU());
@@ -55,6 +57,8 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
SgDNNLBatchDotProperty)
.set_attr("quantize", true);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
SgDNNLPostQuantizeAlignScaleProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCSumFuseProperty)
+ .set_attr("quantize", true);
} // namespace op
} // namespace mxnet
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py
b/tests/python/dnnl/subgraphs/subgraph_common.py
index 37b14c8..b3bf5b0 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -114,8 +114,27 @@ def check_qsym_scale_align(qsym):
assert max_calib_range == v['max_calib_range']
-def check_quantize(net_original, data_shape, out_type, name='conv',
- check_calibration=True, check_scale_align=False):
+def check_fusion_parameter(sym, attrs_dict):
+ for name, attrs in attrs_dict.items():
+ if name in config:
+ op_name = config[name][OP_NAME]
+ else:
+ op_name = name
+ assert ''.join(sym.get_internals().list_outputs()).find(op_name) != -1
+ if len(attrs):
+ found = False
+ for k, v in sym.attr_dict().items():
+ if k.find('_quantize') != -1:
+ continue
+ if k.find(op_name) != -1:
+ found = True
+ for attr_name, attr_value in attrs.items():
+ assert v[attr_name].lower() == attr_value.lower()
+ assert found
+
+def check_quantize(net_original, data_shapes, out_type, name='conv',
+ check_calibration=True, check_scale_align=False,
quantize_mode='full',
+ attrs_dict={}):
quantize_granularity_list = ['tensor-wise']
if name == 'fc':
quantize_granularity_list += ['channel-wise']
@@ -125,92 +144,108 @@ def check_quantize(net_original, data_shape, out_type,
name='conv',
net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
min_value = -1 if out_type != 'uint8' else 0
- data = mx.np.random.uniform(min_value, 1.0, size=data_shape,
dtype='float32', ctx=mx.current_device())
-
- outputs = net_original(data)
+ one_shape = isinstance(data_shapes, tuple)
+ if one_shape:
+ # replace a single shape with a one-element list of shapes so the same handling applies later
+ data_shapes=[data_shapes]
+ data = []
+ for shape in data_shapes:
+ data.append(mx.np.random.uniform(min_value, 1.0, size=shape,
dtype='float32', device=mx.cpu()))
+
+ outputs = net_original(*data)
for output in outputs:
output.wait_to_read()
ref_out = outputs
- calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
+ dataArray= mx.gluon.data.ArrayDataset(*data)
+
+ calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
for quantize_granularity in quantize_granularity_list:
qnet = quantization.quantize_net(net_original,
- ctx=mx.current_device(),
+ device=mx.cpu(),
exclude_layers=None,
exclude_operators=None,
quantized_dtype=out_type,
calib_mode='naive',
calib_data=calib_data,
num_calib_batches=1,
- quantize_mode='full',
+ quantize_mode=quantize_mode,
quantize_granularity=quantize_granularity)
qsym, _ = qnet.export(None)
+ check_fusion_parameter(qsym, attrs_dict)
if check_calibration:
check_qsym_calibrated(qsym, out_type, name=name)
if check_scale_align:
check_qsym_scale_align(qsym)
- quantized_out = qnet(data)
+ quantized_out = qnet(*data)
for i in range(len(ref_out)):
min_range = mx.np.min(ref_out[i]).item()
max_range = mx.np.max(ref_out[i]).item()
atol = 0.1 * max(abs(min_range), abs(max_range))
- assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(),
rtol=0.1, atol=atol, etol=0.2)
+ assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(),
+ rtol=0.1, atol=atol, etol=0.2)
-def check_fusion(net_original, data_shape, attrs_dict, check_fp32_fusion=True,
check_quantization=True,
- out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True):
+def check_fusion(net_original, data_shapes, attrs_dict, check_fp32_fusion=True,
+ check_quantization=True, out_types=['uint8', 'int8', 'auto'],
dedup_subgraph=True,
+ quantize_mode='full'):
net_original.initialize()
net_original.hybridize(static_alloc=False, static_shape=False)
- data = mx.np.random.uniform(size=data_shape, dtype='float32',
ctx=mx.current_device())
- net_original(data)
+ one_shape = isinstance(data_shapes, tuple)
+ data_min = -1.0
+ data_max = 1.0
+
+ if one_shape:
+ # replace a single shape with a one-element list of shapes so the same handling applies later
+ data_shapes=[data_shapes]
+ data = []
+ for shape in data_shapes:
+ data.append(mx.np.random.uniform(size=shape, dtype='float32',
device=mx.cpu(),
+ low=data_min, high=data_max))
+ net_original(*data)
net_fusion = copy.copy(net_original)
sym, params = net_original.export(None)
if check_fp32_fusion:
- data_min = -1.0
- data_max = 1.0
if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
check_quantization = False
data_min = 0
sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph,
skip_infer=True)
- for name, attrs in attrs_dict.items():
- if name in config:
- op_name = config[name][OP_NAME]
- else:
- op_name = name
- assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1
- if len(attrs):
- found = False
- for k, v in sym_sg.attr_dict().items():
- if k.find(op_name) != -1:
- found = True
- for attr_name, attr_value in attrs.items():
- assert v[attr_name].lower() == attr_value.lower()
- assert found
-
- data = mx.np.random.uniform(size=data_shape, low=data_min, high=data_max)
- out_unfused = net_original(data)
-
- net_fusion.optimize_for(data, backend=SG_PASS_NAME)
- out_fused = net_fusion(data)
+ check_fusion_parameter(sym_sg, attrs_dict)
+ if data_min == 0 and mx.npx.is_np_default_dtype():
+ # regenerate inputs if they have different range or data type
+ data = []
+ for shape in data_shapes:
+ data.append(mx.np.random.uniform(size=shape, device=mx.cpu(),
low=data_min, high=data_max))
+ out_unfused = net_original(*data)
+
+ net_fusion.optimize_for(*data, backend=SG_PASS_NAME)
+ out_fused = net_fusion(*data)
assert_almost_equal(out_unfused.asnumpy(), out_fused.asnumpy(), rtol=1e-3,
atol=1e-1)
if check_quantization:
# fp32 to int8
for out_type in out_types:
- check_quantize(net_original, data_shape, out_type, name=name)
+ check_quantize(net_original, data_shapes, out_type,
name=list(attrs_dict.keys())[0],
+ quantize_mode=quantize_mode, attrs_dict=attrs_dict)
def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None,
- data_shapes=(4,4,10,10), name='conv'):
+ data_shapes=[(4,4,10,10)], name='conv'):
op_name = config[name][OP_NAME]
+ one_shape = isinstance(data_shapes, tuple)
+ if one_shape:
+ # replace a single shape with a one-element list of shapes so the same handling applies later
+ data_shapes = [data_shapes]
+ data = []
+ for shape in data_shapes:
+ data.append(mx.np.random.uniform(size=shape))
- data_nd = mx.np.random.uniform(size=data_shapes)
net_original.initialize()
net_original.hybridize()
- net_original(data_nd)
+ net_original(*data)
sym, _ = net_original.export(None)
sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)
@@ -221,4 +256,41 @@ def check_neg_fusion(net_original, attrs_name=None,
excluded_attrs=None,
for attr in attrs_name:
assert v[attr] == 'true'
for exc_attr in excluded_attrs:
- assert exc_attr not in v.keys()
+ assert exc_attr not in v.keys(), exc_attr + " attribute shouldn't exist"
+
+
+
+def check_neg_fusion_quantized(net_original, attrs_name=None,
excluded_attrs=None,
+ data_shapes=[(4,4,10,10)], name='conv'):
+ op_name = config[name][OP_NAME]
+ net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
+ one_shape = isinstance(data_shapes, tuple)
+ if one_shape:
+ # replace a single shape with a one-element list of shapes so the same handling applies later
+ data_shapes=[data_shapes]
+ data = []
+ for shape in data_shapes:
+ data.append(mx.np.random.uniform(size=shape, dtype='float32',
device=mx.cpu()))
+
+ dataArray= mx.gluon.data.ArrayDataset(*data)
+ calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
+
+ qnet = quantization.quantize_net(net_original,
+ device=mx.cpu(),
+ exclude_layers=None,
+ exclude_operators=None,
+ quantized_dtype='int8',
+ calib_mode='naive',
+ calib_data=calib_data,
+ num_calib_batches=1,
+ quantize_mode='full',
+ quantize_granularity='tensor-wise')
+ qsym, _ = qnet.export(None)
+ attrs_dict = qsym.attr_dict()
+ for k, v in attrs_dict.items():
+ if k.find(op_name) != -1:
+ for attr in attrs_name:
+ assert v[attr] == 'true'
+ for exc_attr in excluded_attrs:
+ assert exc_attr not in v.keys(), exc_attr + " attribute shouldn't exist"
+
diff --git a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
index 6b6169b..e7dac8f 100644
--- a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -91,7 +91,7 @@ def test_pos_conv_add(use_bias, data_shape):
attr = {'conv': {'with_sum': 'true'}}
net = ConvAdd(use_bias=use_bias)
- check_fusion(net, data_shape, attr)
+ check_fusion(net, data_shape, attr, check_quantization=False)
@mx.util.use_np
@@ -112,14 +112,14 @@ def test_pos_conv_add2(no_bias, data_shape):
attr = {'conv': {'with_sum': 'true'}}
net = ConvAdd(use_bias=True)
- check_fusion(net, data_shape, attr)
+ check_fusion(net, data_shape, attr, check_quantization=False)
@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('alg,quantize', [
("relu", False), #TODO(bgawrych): investigate
- ("sigmoid", True),
+ ("sigmoid", False),
("log_sigmoid", False),
("mish", False),
("tanh", False), #TODO(bgawrych): investigate
@@ -162,11 +162,11 @@ def test_pos_conv_act_add(data_shape, alg, quantize,
use_bias):
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
- ("sigmoid", True),
- ("log_sigmoid", True),
- ("mish", True),
- ("tanh", True),
- ("softrelu", True),
+ ("sigmoid", False),
+ ("log_sigmoid", False),
+ ("mish", False),
+ ("tanh", False),
+ ("softrelu", False),
("relu6", True),
("leakyrelu", True),
("gelu", True)
@@ -200,14 +200,14 @@ def test_pos_conv_bn_act(use_bias, data_shape, alg,
quantize):
@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('alg,quantize', [
- ("relu", True),
- ("sigmoid", True),
- ("log_sigmoid", True),
- ("mish", True),
- ("tanh", True),
+ ("relu", False),
+ ("sigmoid", False),
+ ("log_sigmoid", False),
+ ("mish", False),
+ ("tanh", False),
#("softrelu", True), #TODO(bgawrych): failing fusion check - difference in
random single element
- ("relu6", True),
- ("leakyrelu", True),
+ ("relu6", False),
+ ("leakyrelu", False),
("gelu", False) #TODO: for True we get assert instead of not fusing pattern
])
@pytest.mark.parametrize('use_bias', [True, False])
@@ -321,11 +321,11 @@ def test_pos_concat_scale_align(data_shape, out_type):
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
- ("sigmoid", True),
- ("log_sigmoid", True),
- ("mish", True),
- ("tanh", True),
- ("softrelu", True),
+ ("sigmoid", False),
+ ("log_sigmoid", False),
+ ("mish", False),
+ ("tanh", False),
+ ("softrelu", False),
("relu6", True),
("leakyrelu", True),
("gelu", True)
diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
index 223a55d..c63bb9a 100644
--- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -17,7 +17,7 @@
import mxnet as mx
import pytest
-from subgraph_common import check_fusion, check_neg_fusion
+from subgraph_common import check_fusion, check_neg_fusion, check_neg_fusion_quantized
from subgraph_common import CustomNormalInit, DATA_SHAPE, TailNegBlock
from mxnet.contrib import quantization
from mxnet.gluon import nn
@@ -89,9 +89,11 @@ def test_fc_eltwise(data_shape, use_bias, flatten, alg):
out = mx.np.clip(fc_out, 0, 1.0)
return out
+ not_quant_fuse = ['sigmoid', 'log_sigmoid', 'softrelu', 'tanh', 'mish', 'square', 'square_root', 'exp']
attrs = {'fc': {'with_eltwise': 'true'}}
net = FCEltwise(use_bias, flatten, alg)
- check_fusion(net, data_shape, attrs, check_quantization=flatten)
+ check_fusion(net, data_shape, attrs, check_quantization=flatten and alg not in not_quant_fuse)
@mx.util.use_np
@@ -148,7 +150,7 @@ def test_quantized_fc_bias_overflow(data_min, data_max, weight_min, weight_max):
conv1 = mx.npx.fully_connected(x, num_hidden=64,
weight=self.weight.data(x.device),
no_bias=False,
bias=self.bias.data(x.device))
return conv1
-
+
def infer_shape(self, x, *args):
self.weight.shape = (64, x.shape[x.ndim-1])
self.bias.shape = (64,)
@@ -232,3 +234,139 @@ def test_fc_identity_eltwise(identity_node):
'sg_onednn_fully_connected_eltwise_1' : {'with_eltwise': 'true'}}
net = FCIdentityEltwise(identity_node)
check_fusion(net, data_shape, attrs, check_quantization=False)
+
+
+def function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, flatten, relu, out_type):
+ class FCWithSumExample(nn.HybridBlock):
+ def __init__(self, num_hidden, add_op, fc_out_add, **kwargs):
+ super(FCWithSumExample, self).__init__(**kwargs)
+ self.fca = nn.Dense(units=num_hidden, flatten=flatten)
+ self.elemwise_add = (add_op == 'elemwise_add')
+ self.fc_out_as_rhs = (fc_out_add == 'rhs')
+ self.relu = (relu == 'leaky_relu')
+
+ def forward(self, data1a, data2):
+ fc_out = self.fca(data1a)
+ if self.relu:
+ fc_out = mx.npx.leaky_relu(fc_out, act_type='gelu')
+ if self.fc_out_as_rhs:
+ if self.elemwise_add:
+ sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray()
+ else:
+ sum1 = data2 + fc_out
+ else:
+ if self.elemwise_add:
+ sum1 = mx.nd.elemwise_add(fc_out.as_nd_ndarray(), data2.as_nd_ndarray()).as_np_ndarray()
+ else:
+ sum1 = fc_out + data2
+ return sum1
+
+ attrs = {'fc': {'with_sum': 'true'}}
+ if quantize_mode is not None:
+ attrs['fc']['quantized'] = 'true'
+ if quantize_mode == 'smart':
+ attrs['fc']['enable_float_output'] = 'true'
+ num_hidden=10
+ net = FCWithSumExample(num_hidden, add_op, fc_out_add)
+ if flatten:
+ data_shapes = [data_shape, (data_shape[0], num_hidden)]
+ else:
+ data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
+ check_fusion(net, data_shapes, attrs,
+ out_types=[out_type],
+ check_fp32_fusion=(quantize_mode is None),
+ check_quantization=(quantize_mode is not None) and flatten,
+ quantize_mode=quantize_mode)
+
+@mx.util.use_np
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('relu', ['noleaky_re', 'leaky_relu'])
+@pytest.mark.parametrize('flatten', ['flat', 'nofl'])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+def test_fc_add(data_shape, add_op, fc_out_add, flatten, relu):
+ function_fc_add(data_shape, add_op, None, fc_out_add, flatten == 'flat', relu, None)
+
+@mx.util.use_np
+@pytest.mark.seed(1234)  # Seed set because the test is not robust enough to operate on random data
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('quantize_mode', ['full', 'smart'])
+@pytest.mark.parametrize('out_type', ['int8', 'auto'])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+def test_fc_add_quantized(data_shape, add_op, quantize_mode, fc_out_add, out_type):
+ function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, True, 'noleaky_re', out_type)
+
+
+class NegFCAdd(nn.HybridBlock):
+ #
+ # data --------------------------> 'add_op' ------------>
+ #                                 /         \
+ # sg_onednn_fully_connected ---->            npi_add -->
+ #                                 \         /
+ #                                  npi_multiply_scalar -->
+ def __init__(self, num_hidden, add_op, fc_out_add, scaled_fc_out, flatten, **kwargs):
+ super(NegFCAdd, self).__init__(**kwargs)
+ self.fca = nn.Dense(units=num_hidden, flatten=flatten)
+ self.elemwise_add = (add_op == 'elemwise_add')
+ self.fc_out_as_rhs = (fc_out_add == 'rhs')
+ self.scaled_fc_out_as_rhs = (scaled_fc_out == 's_rhs')
+
+ def forward(self, data1a, data2):
+ fc_out = self.fca(data1a)
+ scaled_fc_out = fc_out * 200.0
+ if self.fc_out_as_rhs:
+ if self.elemwise_add:
+ sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray()
+ else:
+ sum1 = data2 + fc_out
+ else:
+ if self.elemwise_add:
+ sum1 = mx.nd.elemwise_add(fc_out.as_nd_ndarray(), data2.as_nd_ndarray()).as_np_ndarray()
+ else:
+ sum1 = fc_out + data2
+ if self.scaled_fc_out_as_rhs:
+ sum2 = sum1 + scaled_fc_out
+ else:
+ sum2 = scaled_fc_out + sum1
+ return sum2
+
+@mx.util.use_np
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+@pytest.mark.parametrize('data_shape', [DATA_SHAPE[0]])
+@pytest.mark.parametrize('flatten', ['flat', 'nofl'])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('scaled_fc_out', ['s_lhs', 's_rhs'])
+def test_neg_fc_add(data_shape, add_op, flatten, fc_out_add, scaled_fc_out):
+ '''
+ Test that a FullyConnected operator whose output is not used by only one 'add_op' input is not fused.
+ See NegFCAdd for the graph used.
+ '''
+ flatten = (flatten == 'flat')
+ num_hidden = 10
+ net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, flatten)
+ if flatten:
+ data_shapes = [data_shape, (data_shape[0], num_hidden)]
+ else:
+ data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
+ attrs = []
+ excluded_attrs = ['with_sum']
+ check_neg_fusion(net, attrs, excluded_attrs, data_shapes, name='fc')
+
+@mx.util.use_np
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+@pytest.mark.parametrize('data_shape', [DATA_SHAPE[1]])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('scaled_fc_out', ['s_lhs', 's_rhs'])
+def test_neg_fc_add_quantized(data_shape, add_op, fc_out_add, scaled_fc_out):
+ '''
+ Test that a FullyConnected operator whose output is not used by only one 'add_op' input
+ is not fused in a quantized model.
+ See NegFCAdd for the graph used.
+ '''
+ num_hidden = 10
+ net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, True)
+ data_shapes = [data_shape, (data_shape[0], num_hidden)]
+ attrs = []
+ excluded_attrs = ['with_sum']
+ check_neg_fusion_quantized(net, attrs, excluded_attrs, data_shapes, name='fc')
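
Finally, a minimal end-to-end sketch of the fusion this change targets, outside the test
harness. The 'ONEDNN' backend name and the expectation that the fused node reports
with_sum: 'true' are assumptions based on the tests above, not code from this commit:

    import mxnet as mx
    from mxnet.gluon import nn

    class FCAdd(nn.HybridBlock):
        # Dense output is consumed only by the add, so the ONEDNN pass is
        # expected to fuse FC + elemwise_add and mark the node with 'with_sum'.
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.fc = nn.Dense(units=10, flatten=True)

        def forward(self, x, y):
            fc_out = self.fc(x)
            return mx.nd.elemwise_add(fc_out.as_nd_ndarray(), y.as_nd_ndarray()).as_np_ndarray()

    net = FCAdd()
    net.initialize()
    x = mx.np.random.uniform(size=(4, 16))
    y = mx.np.random.uniform(size=(4, 10))
    net.optimize_for(x, y, backend='ONEDNN')
    sym, _ = net.export(None)
    # Inspect the partitioned graph; the fused FC node should carry with_sum.
    fused = {k: v for k, v in sym.attr_dict().items() if 'fully_connected' in k}
    print(fused)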