This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b98299  [Backport] Enabling BRGEMM FullyConnected based on shapes (#20568)
1b98299 is described below

commit 1b98299a5c244c3701c7df1ec20b73f30f75b8c9
Author: bgawrych <[email protected]>
AuthorDate: Mon Sep 6 13:36:28 2021 +0200

    [Backport] Enabling BRGEMM FullyConnected based on shapes (#20568)
    
    * [v1.x][Feature] Add flag for disabling oneDNN BRGEMM implementation of FC (#20450)
    
    * Add flag for disabling oneDNN BRGEMM implementation of FC
    
    * Review fixes
    
    * Update env_var.md
    
    * [v1.x] Enabling BRGEMM FullyConnected based on shapes (#20533)
    
    * Enable brgemm based on input info
    
    * fix sanity
    
    * Review fixes
    
    * Change function name
    
    * Fix typo
    
    * Align variable assignments
    
    * Fix review
    
    * use const reference
    
    * Update flag name
---
 docs/static_site/src/pages/api/faq/env_var.md    |  4 ++++
 src/operator/nn/mkldnn/mkldnn_base-inl.h         | 16 ++++++++++++++--
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 11 ++++++-----
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index d5234d4..a4b4915 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -333,6 +333,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   - Values: Int ```(default=-1)```
   - Flag to set num of elements that ONEDNN cache can hold. Default is -1 which means cache size is unbounded. Should only be set if your model has variable input shapes, as cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use ONEDNN and different input shape.
 
+* MXNET_ONEDNN_FORCE_FC_AB_FORMAT
+  - Values: 0, 1 ```(default=0)```
+  - If set to true, FullyConnected will use only the AB format for weights, so MXNet won't use the BRGEMM implementation of FC (which requires a special weights format) on machines with AVX512-VNNI support.
+
 * MXNET_ENFORCE_DETERMINISM
   - Values: 0(false) or 1(true) ```(default=0)```
   - If set to true, MXNet will only use deterministic algorithms in forward and backward computation.
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 2ee0793..2af0b5f 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -306,7 +306,16 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
 }
 
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) {
+inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) {
+  // Conditions based on measurement results done on CLX8280
+  // https://github.com/apache/incubator-mxnet/pull/20533
+  return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
+         weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+                                                   size_t batch_size,
+                                                   int dtype = -1) {
   int ndim = arr.shape().ndim();
   mkldnn::memory::dims dims(ndim);
   dtype = (dtype == -1) ? arr.dtype() : dtype;
@@ -314,8 +323,11 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype
     dims[i] = arr.shape()[i];
   auto format = mkldnn::memory::format_tag::any;
   // for batch 256 alexnet benchmark test
+  const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
   if (dims.size() == 2) {
-    format = mkldnn::memory::format_tag::ab;
+    if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+      format = mkldnn::memory::format_tag::ab;
+    }
   }
 
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index a215d28..4bd0b94 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -38,10 +38,11 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
                                                            const NDArray& weight,
                                                            const NDArray* bias,
                                                            const mkldnn::memory::desc& out_md) {
-  auto data_md   = GetMemDesc(data);
-  auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight, mshadow::kInt8)
-                                                     : GetFCWeightDesc(weight);
   auto engine    = CpuEngine::Get()->get_engine();
+  auto data_md   = GetMemDesc(data);
+  auto weight_md = full_param.mkldnn_param.quantized
+                       ? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
+                       : GetFCWeightDesc(weight, data.shape()[0]);
   auto propagation =
       is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
 
@@ -92,7 +93,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
     const NDArray& output,
     mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md   = GetMemDesc(data);
-  auto weight_md = GetFCWeightDesc(weight);
+  auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
   auto out_md    = GetMemDesc(output);
   auto engine    = CpuEngine::Get()->get_engine();
   mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
@@ -106,7 +107,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
     const NDArray& output,
     mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md   = GetMemDesc(data);
-  auto weight_md = GetFCWeightDesc(weight);
+  auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
   auto out_md    = GetMemDesc(output);
   auto engine    = CpuEngine::Get()->get_engine();
   if (bias) {

Reply via email to