This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1b98299 [Backport] Enabling BRGEMM FullyConnected based on shapes
(#20568)
1b98299 is described below
commit 1b98299a5c244c3701c7df1ec20b73f30f75b8c9
Author: bgawrych <[email protected]>
AuthorDate: Mon Sep 6 13:36:28 2021 +0200
[Backport] Enabling BRGEMM FullyConnected based on shapes (#20568)
* [v1.x][Feature] Add flag for disabling oneDNN BRGEMM implementation of FC
(#20450)
* Add flag for disabling oneDNN BRGEMM implementation of FC
* Review fixes
* Update env_var.md
* [v1.x] Enabling BRGEMM FullyConnected based on shapes (#20533)
* Enable brgemm based on input info
* fix sanity
* Review fixes
* Change function name
* Fix typo
* Align variable assignments
* Fix review
* use const reference
* Update flag name
---
docs/static_site/src/pages/api/faq/env_var.md | 4 ++++
src/operator/nn/mkldnn/mkldnn_base-inl.h | 16 ++++++++++++++--
src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 11 ++++++-----
3 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/docs/static_site/src/pages/api/faq/env_var.md
b/docs/static_site/src/pages/api/faq/env_var.md
index d5234d4..a4b4915 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -333,6 +333,10 @@ If ctypes is used, it must be
`mxnet._ctypes.ndarray.NDArrayBase`.
- Values: Int ```(default=-1)```
- Flag to set the number of elements that the ONEDNN cache can hold. Default is -1
which means cache size is unbounded. Should only be set if your model has
variable input shapes, as cache size may grow unbounded. The number represents
the number of items in the cache and is proportional to the number of layers
that use ONEDNN and different input shapes.
+* MXNET_ONEDNN_FORCE_FC_AB_FORMAT
+ - Values: 0, 1 ```(default=0)```
+ - If set to true, FullyConnected will use only the AB format for weights, so
MXNet won't use the BRGEMM implementation of FC on machines with AVX512-VNNI
support, which requires a special weights format.
+
* MXNET_ENFORCE_DETERMINISM
- Values: 0(false) or 1(true) ```(default=0)```
- If set to true, MXNet will only use deterministic algorithms in forward
and backward computation.
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h
b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 2ee0793..2af0b5f 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -306,7 +306,16 @@ inline static mkldnn::memory::desc GetMemDesc(const
NDArray& arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype),
mkldnn::memory::format_tag::any};
}
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int
dtype = -1) {
+inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims,
size_t batch_size) {
+ // Conditions based on measurement results done on CLX8280
+ // https://github.com/apache/incubator-mxnet/pull/20533
+ return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >=
16384 &&
+ weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+ size_t batch_size,
+ int dtype = -1) {
int ndim = arr.shape().ndim();
mkldnn::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
@@ -314,8 +323,11 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const
NDArray& arr, int dtype
dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
+ const bool force_fc_ab_format =
dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
if (dims.size() == 2) {
- format = mkldnn::memory::format_tag::ab;
+ if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+ format = mkldnn::memory::format_tag::ab;
+ }
}
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index a215d28..4bd0b94 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -38,10 +38,11 @@ mkldnn::inner_product_forward::primitive_desc
GetFCFwdImpl(const MKLDNNFCFullPar
const NDArray&
weight,
const NDArray* bias,
const
mkldnn::memory::desc& out_md) {
- auto data_md = GetMemDesc(data);
- auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight,
mshadow::kInt8)
- : GetFCWeightDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
+ auto data_md = GetMemDesc(data);
+ auto weight_md = full_param.mkldnn_param.quantized
+ ? GetFCWeightDesc(weight, data.shape()[0],
mshadow::kInt8)
+ : GetFCWeightDesc(weight, data.shape()[0]);
auto propagation =
is_train ? mkldnn::prop_kind::forward_training :
mkldnn::prop_kind::forward_scoring;
@@ -92,7 +93,7 @@ inline static
mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray& output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
- auto weight_md = GetFCWeightDesc(weight);
+ auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
@@ -106,7 +107,7 @@ inline static
mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
const NDArray& output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
- auto weight_md = GetFCWeightDesc(weight);
+ auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {