anko-intel commented on a change in pull request #20821:
URL: https://github.com/apache/incubator-mxnet/pull/20821#discussion_r790530664



##########
File path: src/operator/subgraph/dnnl/dnnl_fc.cc
##########
@@ -72,83 +75,123 @@ class SgDNNLFCOp {
   std::shared_ptr<dnnl::memory> cached_out_mem_;
   NDArray cached_weight_;
   NDArray cached_bias_;
-  float cached_min_data_;
-  float cached_max_data_;
-  float cached_min_weight_;
-  float cached_max_weight_;
-  float cached_min_bias_;
-  float cached_max_bias_;
+  float cached_data_min_;
+  float cached_data_max_;
+  float cached_weight_min_;
+  float cached_weight_max_;
+  float cached_sum_min_;
+  float cached_sum_max_;
+  float cached_bias_min_;
+  float cached_bias_max_;
   size_t weight_ver_;
   size_t bias_ver_;
-  float cached_min_output_;
-  float cached_max_output_;
+  float cached_output_min_;
+  float cached_output_max_;
   float data_scale_{0.0f};
   std::vector<float> weight_scales_;
-  size_t total_num_inputs_;
-  size_t total_num_outputs_;
 };
 
 void SgDNNLFCOp::Forward(const OpContext& ctx,
                          const std::vector<NDArray>& in_data,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& out_data) {
-  auto& dnnl_param        = full_param_.dnnl_param;
-  auto& default_param     = full_param_.default_param;
-  bool has_bias           = !default_param.no_bias;
-  size_t base_num_inputs  = has_bias ? 3 : 2;
-  size_t base_num_outputs = 1;
-
-  float min_data   = 0.0f;
-  float max_data   = 0.0f;
-  float min_weight = 0.0f;
-  float max_weight = 0.0f;
-  float min_bias   = 0.0f;
-  float max_bias   = 0.0f;
-
-  if (!initialized_) {
-    if (dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize) {
-      channel_wise_runtime_ = true;
+  auto& dnnl_param         = full_param_.dnnl_param;
+  auto& default_param      = full_param_.default_param;
+  const bool has_bias      = !default_param.no_bias;
+  const bool quantized     = dnnl_param.quantized;
+  const bool out_quantized = dnnl_param.quantized && !dnnl_param.enable_float_output;
+  const bool channel_wise  = quantized && dnnl_param.channel_wise_quantize.has_value() &&
+                            dnnl_param.channel_wise_quantize.value();
+
+  const FCInputIndex idx(full_param_);
+
+  CHECK_EQ(in_data.size(), idx.GetTotal());
+
+  int index               = 0;
+  const int out_index     = index++;
+  const int out_min_index = out_quantized ? index++ : 0;
+  const int out_max_index = out_quantized ? index++ : 0;
+  CHECK_EQ(out_data.size(), index);  // index is equal to total number of outputs
+
+  float data_min   = 0.0f;
+  float data_max   = 0.0f;
+  float weight_min = 0.0f;
+  float weight_max = 0.0f;
+  float bias_min   = 0.0f;
+  float bias_max   = 0.0f;
+
+  const float sum_min   = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
+  const float sum_max   = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
+  NDArray data          = in_data[idx.data];
+  const NDArray& weight = in_data[idx.weight];
+  NDArray output;
+
+  if (dnnl_param.with_sum) {
+    if (!initialized_) {
+      // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option,
+      // which make check (req[out_index] == kWriteInplace) useless.
+      auto in_dnnl_mem  = static_cast<const dnnl::memory*>(in_data[idx.sum].GetDNNLData());
+      auto out_dnnl_mem = static_cast<const dnnl::memory*>(out_data[out_index].GetDNNLData());
+      if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle()) {
+        inplace_ = true;
+      }
     }
-
-    total_num_inputs_  = base_num_inputs;
-    total_num_outputs_ = base_num_outputs;
-    if (dnnl_param.quantized) {
-      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
-      total_num_outputs_ =
-          dnnl_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
+    if (inplace_) {
+      output = in_data[idx.sum];
+    } else {
+      // Not in place: copy in_data[idx.sum] into outputs[out_index].
+      auto in_dnnl_mem  = static_cast<const dnnl::memory*>(in_data[idx.sum].GetDNNLData());
+      auto out_dnnl_mem = static_cast<const dnnl::memory*>(out_data[out_index].GetDNNLData());
+      if (out_data[out_index].dtype() == mshadow::kInt32) {

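The indexing scheme in this hunk (FCInputIndex for the inputs, the index++ chain for
the outputs) assigns a slot to an optional tensor only when the matching feature is
enabled, and reuses 0 as an "absent" sentinel because slot 0 always belongs to the
mandatory data tensor. A minimal, self-contained sketch of that pattern follows;
FCInputIndexSketch and its constructor flags are hypothetical stand-ins, not the
real MXNet types, which are built from full_param_ instead:

    #include <cassert>

    // Hypothetical mirror of the FCInputIndex pattern from the diff above.
    struct FCInputIndexSketch {
      int data{0}, weight{0}, bias{0}, sum{0}, sum_min{0}, sum_max{0};
      int total{0};

      FCInputIndexSketch(bool has_bias, bool with_sum, bool sum_quantized) {
        int index = 0;
        data    = index++;                 // mandatory, always slot 0
        weight  = index++;                 // mandatory
        bias    = has_bias ? index++ : 0;  // 0 doubles as "absent": slot 0
        sum     = with_sum ? index++ : 0;  // already belongs to `data`
        sum_min = (with_sum && sum_quantized) ? index++ : 0;
        sum_max = (with_sum && sum_quantized) ? index++ : 0;
        total   = index;                   // mirrors idx.GetTotal()
      }
      int GetTotal() const { return total; }
    };

    int main() {
      FCInputIndexSketch idx(/*has_bias=*/true, /*with_sum=*/true, /*sum_quantized=*/true);
      assert(idx.GetTotal() == 6);  // data, weight, bias, sum, sum_min, sum_max
      assert(idx.sum_min == 4);     // non-zero: `idx.sum_min ? ... : 0.0` reads it

      FCInputIndexSketch plain(/*has_bias=*/false, /*with_sum=*/false, /*sum_quantized=*/false);
      assert(plain.GetTotal() == 2);  // data and weight only
      assert(plain.sum_min == 0);     // zero: treated as absent, falls back to 0.0
      return 0;
    }
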
Review comment:
       For now such a configuration doesn't exist. In 'full' quantization mode the input and output are int8, and in 'smart' quantization mode both are float.
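
To make the reviewer's point concrete, here is a hedged sketch (a hypothetical
helper, not MXNet code, with `enable_float_output` standing in for
dnnl_param.enable_float_output): neither of the two quantization modes named
above produces an int32 output for the fused sum, so the kInt32 branch guards a
configuration that currently cannot occur.

    #include <cassert>

    // Hypothetical illustration of the two modes named in the comment.
    enum class DType { kInt8, kFloat32, kInt32 };

    DType ExpectedSumOutputDType(bool enable_float_output) {
      // 'smart' quantization keeps float input/output; 'full' quantization
      // uses int8 on both ends. Neither yields int32.
      return enable_float_output ? DType::kFloat32 : DType::kInt8;
    }

    int main() {
      assert(ExpectedSumOutputDType(true) != DType::kInt32);   // 'smart' mode
      assert(ExpectedSumOutputDType(false) != DType::kInt32);  // 'full' mode
      return 0;
    }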



