anko-intel commented on a change in pull request #20533:
URL: https://github.com/apache/incubator-mxnet/pull/20533#discussion_r691227672
##########
File path: src/operator/nn/mkldnn/mkldnn_fully_connected.cc
##########
@@ -42,9 +42,10 @@ mkldnn::inner_product_forward::primitive_desc
GetFCFwdImpl(const MKLDNNFCFullPar
const NDArray* bias,
const
mkldnn::memory::desc& out_md) {
auto data_md = GetMemDesc(data);
- auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight,
mshadow::kInt8)
- : GetFCWeightDesc(weight);
- auto engine = CpuEngine::Get()->get_engine();
+ auto weight_md = full_param.mkldnn_param.quantized
+ ? GetFCWeightDesc(weight, data.shape()[0],
mshadow::kInt8)
+ : GetFCWeightDesc(weight, data.shape()[0]);
+ auto engine = CpuEngine::Get()->get_engine();
Review comment:
This line could stay aligned on "="
##########
File path: src/operator/nn/mkldnn/mkldnn_base-inl.h
##########
@@ -305,17 +305,26 @@ inline static mkldnn::memory::desc GetMemDesc(const
NDArray& arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype),
mkldnn::memory::format_tag::any};
}
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int
dtype = -1) {
+inline static bool SupportBRGEMMImpl(mkldnn::memory::dims weight_dims, size_t
batch_size) {
+ return weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0 &&
weight_dims[0] >= 1024 &&
+ weight_dims[1] >= 1024 && batch_size >= 2 << 13;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+ size_t batch_size,
+ int dtype = -1) {
int ndim = arr.shape().ndim();
mkldnn::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++)
dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
- const bool brgemm_disabled = dmlc::GetEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC",
true);
- if (dims.size() == 2 && brgemm_disabled) {
- format = mkldnn::memory::format_tag::ab;
+ const bool force_fc_ab_format =
dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
Review comment:
Please update docs/static_site/src/pages/api/faq/env_var.md to document the new MXNET_MKLDNN_FORCE_FC_AB_FORMAT environment variable.
##########
File path: src/operator/nn/mkldnn/mkldnn_base-inl.h
##########
@@ -305,17 +305,26 @@ inline static mkldnn::memory::desc GetMemDesc(const
NDArray& arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype),
mkldnn::memory::format_tag::any};
}
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int
dtype = -1) {
+inline static bool SupportBRGEMMImpl(mkldnn::memory::dims weight_dims, size_t
batch_size) {
Review comment:
```suggestion
inline static bool ChooseBRGEMMImpl(mkldnn::memory::dims weight_dims, size_t
batch_size) {
```
and maybe add a comment explaining that this function selects the brgemm implementation when
it is faster than ...
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]