This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1eb3344 Support Quantized Fully Connected by INT8 GEMM (#12922)
1eb3344 is described below
commit 1eb334457259d15044534802d5d19545a0960038
Author: Hao Li <[email protected]>
AuthorDate: Sat Dec 15 13:35:08 2018 +0800
Support Quantized Fully Connected by INT8 GEMM (#12922)
* add quantized fully connected support
* disable qfc cpu case since s8u8s32 is only supported by MKL BLAS library
* retrigger ci testing
* move implementation to cc file and add STORAGE_TYPE_ASSIGN_CHECK
* fix typo bug
* retrigger the ci test
* fix typo bug
* retrigger ci
* retrigger the ci test
* retrigger the ci
* retrigger the ci test
* retrigger ci test
* fix indent issue
* retrigger the ci
* retrigger the ci test
* add verbose message
* update log message
* using range for loop
* using for auto range
* enable MKL BLAS ci test
* fix typo issue
* use TYPE_ASSIGN_CHECK
* retrigger the ci
---
.../quantization/quantized_fully_connected.cc | 159 ++++++++++++++++++++-
tests/python/quantization/test_quantization.py | 26 +++-
2 files changed, 177 insertions(+), 8 deletions(-)
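A note before the diff: the core of this change is a workaround for cblas_gemm_s8u8s32, which
only accepts an unsigned A matrix. The sketch below is my own illustration of the identity the
operator relies on, not code from the commit (the function name and the layout assumptions,
row-major m x k data against n x k weights, are mine):

    #include <cstdint>
    #include <vector>

    // Reference for the shift-and-compensate trick: shift int8 activations by +128
    // into uint8, then subtract the spurious 128 * sum_k weight[j][k] from every
    // output column j so the result equals the plain int8 x int8 product.
    std::vector<int32_t> int8_fc_reference(const std::vector<int8_t>& data,    // m x k
                                           const std::vector<int8_t>& weight,  // n x k
                                           int m, int n, int k) {
      const int32_t shift = 128;
      std::vector<uint8_t> shifted(static_cast<size_t>(m) * k);
      for (int i = 0; i < m * k; ++i)
        shifted[i] = static_cast<uint8_t>(data[i] + shift);
      // Pre-seed each output with the compensation term; the committed code also
      // folds the rescaled bias in here and relies on the GEMM's beta = 1.0.
      std::vector<int32_t> out(static_cast<size_t>(m) * n);
      for (int j = 0; j < n; ++j) {
        int32_t comp = 0;
        for (int x = 0; x < k; ++x) comp -= shift * weight[j * k + x];
        for (int i = 0; i < m; ++i) out[i * n + j] = comp;
      }
      // The uint8 x int8 -> int32 product; cblas_gemm_s8u8s32 with CblasTrans on B
      // performs this step in QuantizedFullyConnectedForward below.
      for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
          for (int x = 0; x < k; ++x)
            out[i * n + j] += static_cast<int32_t>(shifted[i * k + x]) * weight[j * k + x];
      return out;
    }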
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index e334fe7..64ce73b 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -23,11 +23,17 @@
* \brief
* \author Ziheng Jiang, Jun Wu
*/
+#include <vector>
+#include "quantization_utils.h"
#include "../nn/fully_connected-inl.h"
namespace mxnet {
namespace op {
+namespace quantized_fc {
+enum QuantizedfcOpResource {kTempSpace};
+}
+
bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
@@ -79,6 +85,151 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
return true;
}
+bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ *dispatch_mode = DispatchMode::kFCompute;
+ if (dev_mask == mshadow::cpu::kDevMask) {
+ *dispatch_mode = DispatchMode::kFComputeEx;
+ }
+
+ for (auto &v : *out_attrs) {
+ v = kDefaultStorage;
+ if (common::stype_string(v).compare("unknown") == 0) {
+ return false;
+ }
+ }
+
+ for (auto &v : *in_attrs) {
+ v = kDefaultStorage;
+ if (common::stype_string(v).compare("unknown") == 0) {
+ return false;
+ }
+ }
+ return true;
+}
+
+struct QuantizedSumInitKernelWithBias {
+ // init sum data with bias for matrix b (n)
+ MSHADOW_XINLINE static void Map(int i, int32_t *out,
+ const int8_t *bias, const float *min_out,
+ const float *max_out, const float *min_bias,
+ const float *max_bias) {
+ typedef int32_t T1;
+ typedef int8_t T2;
+ using mshadow::red::limits::MinValue;
+ using mshadow::red::limits::MaxValue;
+ float float_for_one_out_quant =
+ MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
+ float float_for_one_bias_quant =
+ MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
+ if (float_for_one_out_quant != 0) {
+ out[i] = bias[i] * float_for_one_bias_quant /
+ float_for_one_out_quant;
+ } else {
+ LOG(INFO) << "float_for_one_out_quant is 0,"
+ << " need to check why MaxAbs(*min_out, *max_out) of out_data is 0!";
+ out[i] = 0;
+ }
+ }
+};
+
+
+template<typename SrcType>
+void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data) {
+#if MSHADOW_USE_MKL == 1
+ const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ using namespace mshadow;
+ using namespace mxnet_op;
+ size_t num_inputs = param.no_bias ? 2 : 3;
+ CHECK_EQ(in_data.size(), num_inputs * 3);
+ CHECK_EQ(out_data.size(), 3U);
+ const NDArray& data = in_data[0];
+ const NDArray& weight = in_data[1];
+ const NDArray& out = out_data[0];
+ TShape dshape = data.shape();
+ TShape wshape = weight.shape();
+ TShape oshape = out.shape();
+ auto output_temp = out.data().dptr<int32_t>();
+ auto weight_temp = weight.data().dptr<SrcType>();
+ auto data_temp = data.data().dptr<SrcType>();
+ const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ const float alpha = 1.0f;
+ const float beta = 1.0f;
+ const CBLAS_OFFSET offsetc = CblasFixOffset;
+ const MKL_INT8 oa = 0;
+ const MKL_INT8 ob = 0;
+ MKL_INT32 oc = 0;
+ const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
+ Stream<cpu> *s = ctx.get_stream<cpu>();
+ // cblas_gemm_s8u8s32 requires the first matrix to be uint8, so
+ // shift data from int8 (-128 to 127) to uint8 (0 to 255)
+ int shift = 128;
+ Tensor<cpu, 1, uint8_t> shiftdata =
+ ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+ Shape1(m * k), s);
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int i = 0; i < m * k; ++i) {
+ shiftdata.dptr_[i] = data_temp[i] + shift;
+ }
+
+ Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+ out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+ in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+ in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+ if (!param.no_bias) {
+ const NDArray& bias = in_data[2];
+ Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
+ bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
+ out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
+ in_data[8].data().dptr<float>());
+ } else {
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int i = 0; i < m * n; ++i) {
+ output_temp[i] = 0;
+ }
+ }
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int i = 0; i < n; ++i) {
+ for (int j = 0; j < k; ++j) {
+ output_temp[i] -= shift * weight_temp[i * k + j];
+ }
+ }
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int i = n; i < m * n; ++i) {
+ output_temp[i] = output_temp[i % n];
+ }
+ cblas_gemm_s8u8s32(CblasRowMajor,
+ CblasNoTrans,
+ CblasTrans,
+ offsetc,
+ m,
+ n,
+ k,
+ alpha,
+ shiftdata.dptr_,
+ k,
+ oa,
+ weight.data().dptr<SrcType>(),
+ k,
+ ob,
+ beta,
+ out.data().dptr<int32_t>(),
+ n,
+ &oc);
+#else
+ LOG(FATAL) << "Quantized fully connected operator relies on cblas_gemm_s8u8s32"
+ << " which is only supported by MKL BLAS."
+ << " Please build MXNet with USE_BLAS=mkl to leverage this operator.";
+#endif
+}
+
NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.describe(R"code(Fully Connected operator for input, weight and bias data type
of int8,
and accumulates in type int32 for the output. For each argument, two more
arguments of type
@@ -112,7 +263,14 @@ and max thresholds representing the thresholds for quantizing the float32 output
})
.set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.set_attr<FComputeEx>("FComputeEx<cpu>",
+ QuantizedFullyConnectedForward<int8_t>)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "weight.")
.add_argument("bias", "NDArray-or-Symbol", "bias.")
@@ -135,6 +293,5 @@ NNVM_REGISTER_OP(FullyConnected)
}
return node;
});
-
} // namespace op
} // namespace mxnet
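An aside on QuantizedSumInitKernelWithBias above: the kernel is a direct transcription of the
symmetric-quantization convention used by quantization_utils.h, where an integer q over the
float range [min, max] stands for q * MaxAbs(min, max) / MaxValue<T>. Rewriting the per-element
computation as math (my notation; T1 = int32 for the output and T2 = int8 for the bias, as in
the typedefs):

    \[
    s_{\text{bias}} = \frac{\mathrm{MaxAbs}(min_{\text{bias}}, max_{\text{bias}})}{2^{7}-1},
    \qquad
    s_{\text{out}} = \frac{\mathrm{MaxAbs}(min_{\text{out}}, max_{\text{out}})}{2^{31}-1},
    \qquad
    out_i = bias_i \cdot \frac{s_{\text{bias}}}{s_{\text{out}}}
    \]

so the int8 bias is re-expressed on the int32 output scale before the GEMM accumulates into it.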
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 518b696..3ff4b69 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -26,6 +26,7 @@ from common import with_seed
from mxnet.module import Module
from mxnet.io import NDArrayIter
import unittest
+import operator
def is_test_for_gpu():
return mx.current_context().device_type == 'gpu'
@@ -278,8 +279,15 @@ def test_quantized_pooling():
def test_quantized_fc():
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype,
flatten=True):
if mx.current_context().device_type != 'gpu':
- print('skipped testing quantized_fc on cpu since it is not supported yet')
- return
+ hasMKL = False
+ for key in os.environ.keys():
+ if operator.eq(key, "BUILD_TAG"):
+ if os.environ['BUILD_TAG'].find("MKL") != -1:
+ hasMKL = True
+ break
+ if hasMKL == False:
+ print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
+ return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
return
@@ -291,16 +299,16 @@ def test_quantized_fc():
fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
if qdtype == 'uint8':
data_low = 0.0
- data_high = 127.0
+ data_high = 63.0
else:
- data_low = -127.0
- data_high = 127.0
+ data_low = -63.0
+ data_high = 63.0
fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=data_shape).astype('int32')
- fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+ fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[1]).astype('int32')
if not no_bias:
- fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+ fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[2]).astype('int32')
output = fc_fp32_exe.forward()[0]
@@ -343,6 +351,10 @@ def test_quantized_fc():
check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
+ check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
+ check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
+ check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
+ check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
@with_seed()
def test_quantized_flatten():
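One closing arithmetic check of my own (not something the commit itself claims): even the
largest new test shape, (256, 2048, 2, 2), keeps the shifted uint8 x int8 accumulation well
inside int32, since the reduction length after flattening is k = 2048 * 2 * 2 = 8192:

    #include <cstdint>
    #include <limits>

    int main() {
      constexpr int64_t k = 2048 * 2 * 2;           // reduction length after flatten
      constexpr int64_t max_a = 255;                // uint8 activation after the +128 shift
      constexpr int64_t max_w = 128;                // largest |int8| weight magnitude
      constexpr int64_t worst = k * max_a * max_w;  // 267,386,880
      static_assert(worst < std::numeric_limits<int32_t>::max(),
                    "int32 accumulator has headroom for these test shapes");
      return 0;
    }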