This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1eb3344  Support Quantized Fully Connected by INT8 GEMM (#12922)
1eb3344 is described below

commit 1eb334457259d15044534802d5d19545a0960038
Author: Hao Li <[email protected]>
AuthorDate: Sat Dec 15 13:35:08 2018 +0800

    Support Quantized Fully Connected by INT8 GEMM (#12922)
    
    * add quantized fully connect support
    
    * disable qfc cpu case since s8u8s32 is only supported by MKL BLAS library
    
    * retrigger to ci testing
    
    * move implementation to cc file and add  STORAGE_TYPE_ASSIGN_CHECK
    
    * fix typo bug
    
    * retrigger the ci test
    
    * fix typo bug
    
    * retrigger ci
    
    * retrigger the ci test
    
    * retrigger the ci
    
    * retrigger the ci test
    
    * retrigger ci test
    
    * fix indent issue
    
    * retrigger the ci
    
    * retrigger the ci test
    
    * add verbose message
    
    * update log message
    
    * using range for loop
    
    * using for auto range
    
    * enable MKL BLAS ci test
    
    * fix typo issue
    
    * use TYPE_ASSIGN_CHECK
    
    * retrigger the ci
---
 .../quantization/quantized_fully_connected.cc      | 159 ++++++++++++++++++++-
 tests/python/quantization/test_quantization.py     |  26 +++-
 2 files changed, 177 insertions(+), 8 deletions(-)

diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index e334fe7..64ce73b 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -23,11 +23,17 @@
  * \brief
  * \author Ziheng Jiang, Jun Wu
 */
+#include <vector>
+#include "quantization_utils.h"
 #include "../nn/fully_connected-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace quantized_fc {
+enum QuantizedfcOpResource {kTempSpace};
+}
+
 bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
                                   std::vector<TShape> *in_shape,
                                   std::vector<TShape> *out_shape) {
@@ -79,6 +85,151 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
+                                        const int dev_mask,
+                                        DispatchMode* dispatch_mode,
+                                        std::vector<int> *in_attrs,
+                                        std::vector<int> *out_attrs) {
+  *dispatch_mode = DispatchMode::kFCompute;
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+
+  for (auto &v : *out_attrs) {
+    v = kDefaultStorage;
+    if (common::stype_string(v).compare("unknown") == 0) {
+      return false;
+    }
+  }
+
+  for (auto &v : *in_attrs) {
+    v = kDefaultStorage;
+    if (common::stype_string(v).compare("unknown") == 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+struct QuantizedSumInitKernelWithBias {
+  //  init sum data with bias for matrix b (n)
+  MSHADOW_XINLINE static void Map(int i, int32_t *out,
+                                  const int8_t *bias, const float *min_out,
+                                  const float *max_out, const float *min_bias,
+                                  const float *max_bias) {
+    typedef int32_t T1;
+    typedef int8_t  T2;
+    using mshadow::red::limits::MinValue;
+    using mshadow::red::limits::MaxValue;
+    float float_for_one_out_quant  =
+        MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
+    float float_for_one_bias_quant =
+        MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
+    if (float_for_one_out_quant != 0) {
+      out[i] = bias[i] * float_for_one_bias_quant /
+          float_for_one_out_quant;
+    } else {
+      LOG(INFO) << "float_for_one_out_quant is 0,"
+                << " need to check the why MaxAbs(*min_out, *max_out) of out_data is 0!";
+      out[i] = 0;
+    }
+  }
+};
+
+
+template<typename SrcType>
+void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &in_data,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &out_data) {
+#if MSHADOW_USE_MKL == 1
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  using namespace mshadow;
+  using namespace mxnet_op;
+  size_t num_inputs = param.no_bias ? 2 : 3;
+  CHECK_EQ(in_data.size(),  num_inputs * 3);
+  CHECK_EQ(out_data.size(), 3U);
+  const NDArray& data = in_data[0];
+  const NDArray& weight = in_data[1];
+  const NDArray& out = out_data[0];
+  TShape dshape = data.shape();
+  TShape wshape = weight.shape();
+  TShape oshape = out.shape();
+  auto output_temp = out.data().dptr<int32_t>();
+  auto weight_temp = weight.data().dptr<SrcType>();
+  auto data_temp = data.data().dptr<SrcType>();
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  const float alpha = 1.0f;
+  const float beta  = 1.0f;
+  const CBLAS_OFFSET offsetc = CblasFixOffset;
+  const MKL_INT8 oa = 0;
+  const MKL_INT8 ob = 0;
+  MKL_INT32 oc = 0;
+  const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  //  cblas_gemm_s8u8s32 required first matrix must be uint8
+  //  shift data from int8(from -128 to 127) to uint8 (from 0 to 255)
+  int shift = 128;
+  Tensor<cpu, 1, uint8_t> shiftdata =
+    ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+      Shape1(m * k), s);
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < m * k; ++i) {
+    shiftdata.dptr_[i] = data_temp[i] + shift;
+  }
+
+  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+      in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+      in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+  if (!param.no_bias) {
+    const NDArray& bias = in_data[2];
+    Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
+        bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
+        out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
+        in_data[8].data().dptr<float>());
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < m * n; ++i) {
+      output_temp[i] = 0;
+    }
+  }
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < k; ++j) {
+      output_temp[i] -= shift * weight_temp[i * k + j];
+    }
+  }
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = n; i < m * n; ++i) {
+    output_temp[i] = output_temp[i % n];
+  }
+  cblas_gemm_s8u8s32(CblasRowMajor,
+                     CblasNoTrans,
+                     CblasTrans,
+                     offsetc,
+                     m,
+                     n,
+                     k,
+                     alpha,
+                     shiftdata.dptr_,
+                     k,
+                     oa,
+                     weight.data().dptr<SrcType>(),
+                     k,
+                     ob,
+                     beta,
+                     out.data().dptr<int32_t>(),
+                     n,
+                     &oc);
+#else
+  LOG(FATAL) << "Quantized fully connected operator relies on cblas_gemm_s8u8s32"
+             << " which is only supported by MKL BLAS."
+             << " Please build MXNet with USE_BLAS=mkl to leverage this operator.";
+#endif
+}
+
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 .describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
 and accumulates in type int32 for the output. For each argument, two more arguments of type
@@ -112,7 +263,14 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.set_attr<FComputeEx>("FComputeEx<cpu>",
+    QuantizedFullyConnectedForward<int8_t>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "weight.")
 .add_argument("bias", "NDArray-or-Symbol", "bias.")
@@ -135,6 +293,5 @@ NNVM_REGISTER_OP(FullyConnected)
     }
     return node;
   });
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 518b696..3ff4b69 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -26,6 +26,7 @@ from common import with_seed
 from mxnet.module import Module
 from mxnet.io import NDArrayIter
 import unittest
+import operator
 
 def is_test_for_gpu():
     return mx.current_context().device_type == 'gpu'
@@ -278,8 +279,15 @@ def test_quantized_pooling():
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         if mx.current_context().device_type != 'gpu':
-            print('skipped testing quantized_fc on cpu since it is not supported yet')
-            return
+            hasMKL = False;
+            for key in os.environ.keys():
+                if operator.eq(key, "BUILD_TAG"):
+                    if os.environ['BUILD_TAG'].find("MKL") != -1:
+                        hasMKL = True
+                    break
+            if hasMKL == False:
+                print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
+                return
         elif qdtype == 'uint8' and is_test_for_gpu():
             print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
             return
@@ -291,16 +299,16 @@ def test_quantized_fc():
         fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
         if qdtype == 'uint8':
             data_low = 0.0
-            data_high = 127.0
+            data_high = 63.0
         else:
-            data_low = -127.0
-            data_high = 127.0
+            data_low = -63.0
+            data_high = 63.0
         fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
                                                                      shape=data_shape).astype('int32')
-        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
                                                                      shape=arg_shapes[1]).astype('int32')
         if not no_bias:
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
                                                                          shape=arg_shapes[2]).astype('int32')
         output = fc_fp32_exe.forward()[0]
 
@@ -343,6 +351,10 @@ def test_quantized_fc():
         check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
         check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
         check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
 
 @with_seed()
 def test_quantized_flatten():

Reply via email to