ZhennanQin commented on a change in pull request #14641: [MKLDNN]Improve quantizeV2 and dequantize latency
URL: https://github.com/apache/incubator-mxnet/pull/14641#discussion_r276049254
##########
File path: src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
##########
@@ -26,82 +26,104 @@
#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
#if MXNET_USE_MKLDNN == 1
-#include <string>
#include <algorithm>
+#include <string>
#include <vector>
#include "../../nn/mkldnn/mkldnn_base-inl.h"
namespace mxnet {
namespace op {
-template<typename SrcType, typename DstType>
-static void MKLDNNDequantizeComputeKer(const std::vector<NDArray> &inputs,
- const std::vector<NDArray> &outputs,
- const std::vector<OpReqType> &req) {
- using namespace mshadow;
- using namespace mxnet_op;
- using red::limits::MaxValue;
- using red::limits::MinValue;
- float real_range = 0.0;
- float quantized_range = 0.0;
- if (inputs[0].dtype() == mshadow::kUint8) {
- quantized_range = MaxAbs(MaxValue<SrcType>(), MinValue<SrcType>());
- real_range = MaxAbs(*inputs[1].data().dptr<DstType>(), *inputs[2].data().dptr<DstType>());
- } else if (inputs[0].dtype() == mshadow::kInt8) {
- quantized_range = MinAbs(MaxValue<SrcType>(), MinValue<SrcType>());
- real_range = MaxAbs(*inputs[1].data().dptr<DstType>(), *inputs[2].data().dptr<DstType>());
- } else {
- LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type";
- }
- float scale = real_range / quantized_range;
- primitive_attr attr;
- const int mask = 0;
- std::vector<float> scales = {scale};
- attr.set_output_scales(mask, scales);
- attr.set_int_output_round_mode(round_nearest);
- mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
- NDArray in_buffer = inputs[0];
- if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
- in_buffer = inputs[0].Reorder2Default();
+class SgMKLDNNDequantizeOperator {
+ public:
+ explicit SgMKLDNNDequantizeOperator(const nnvm::NodeAttrs &attrs)
+ : param_(nnvm::get<DequantizeParam>(attrs.parsed)) {}
+ void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs);
+
+ private:
+ bool initalized_{false};
Review comment:
Seems I failed to spell this word correctly a long time ago. Thanks for pointing this out.
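
For context, the diff quoted above replaces the per-call MKLDNNDequantizeComputeKer with a stateful SgMKLDNNDequantizeOperator that keeps an initialized flag, so expensive setup happens only on the first Forward call. The following is a minimal, self-contained sketch of that lazy-initialization pattern; the names (CachedDequantize, Primitive) and the toy scale computation are illustrative assumptions, not the PR's actual implementation:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for an expensive-to-build MKLDNN primitive.
struct Primitive {
  float scale;
};

class CachedDequantize {
 public:
  void Forward(const std::vector<int8_t> &in, std::vector<float> *out) {
    if (!initialized_) {
      // Expensive setup runs exactly once (in the real operator: reading the
      // min/max inputs, computing the output scale, creating the reorder
      // primitive). Later calls reuse the cached state, which is where the
      // latency improvement comes from.
      prim_.scale = 1.0f / 127.0f;
      initialized_ = true;
    }
    out->resize(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
      (*out)[i] = in[i] * prim_.scale;  // cheap per-call work
    }
  }

 private:
  bool initialized_{false};  // the member whose spelling the review discusses
  Primitive prim_;
};

int main() {
  CachedDequantize op;
  std::vector<float> out;
  op.Forward({-128, 0, 127}, &out);
  std::cout << out[2] << std::endl;  // prints ~1.0
  return 0;
}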
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services