This is an automated email from the ASF dual-hosted git repository.
samskalicky pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.7.x by this push:
new 651e24b [v1.7.x] backport Invoke mkldnn and cudnn BatchNorm when
axis != 1 to v1.7.x (#18676)
651e24b is described below
commit 651e24b281e9e2b405808bd4ec0423a4e6a5a946
Author: Jake Lee <[email protected]>
AuthorDate: Mon Aug 10 11:37:35 2020 -0700
[v1.7.x] backport Invoke mkldnn and cudnn BatchNorm when axis != 1 to
v1.7.x (#18676)
* [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)
* fix batch norm when fix_gamma is True
* support gradient accumulation for batch norm
* mkldnn batchnorm support grad add
* unittest for bn
* fix bn arg
* fix lint
* fix mkldnn
* fix mkldnn bn
* fix grad when fixing gamma
* fix naive gpu bn
* fix lint
* invoke mkldnn and cudnn batchnorm when axis != 1
* backport 18500
* change condition
* fix
* fix
* add mkldnn_off for bn
* remove mkldnn_off
* recover save_000800.json
* cast
* remove and fix flaky test
Co-authored-by: JackieWu <[email protected]>
---
src/operator/nn/batch_norm.cc | 12 ++++---
src/operator/nn/batch_norm.cu | 6 ++--
src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 26 +++++++++++----
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 44 +++++++++++++++++++++++---
tests/python/unittest/test_numpy_op.py | 2 +-
tests/python/unittest/test_operator.py | 2 +-
tests/python/unittest/test_symbol.py | 30 ------------------
7 files changed, 70 insertions(+), 52 deletions(-)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index b8961df..42cb6c2 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -422,10 +422,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
#if MXNET_USE_MKLDNN == 1
static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam
&param) {
- mxnet::TShape shape = input.shape();
- return SupportMKLDNN(input) && shape.ndim() == 4
- && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
- && !mxnet::op::batchnorm::disable_mkl;
+ if (mxnet::op::batchnorm::disable_mkl) return false;
+ const mxnet::TShape shape = input.shape();
+ const int ndim = shape.ndim();
+ if (ndim == 0 || shape.Size() == 0) return false;
+ const int dtype = input.dtype();
+ return (dtype == mshadow::kFloat32 ||
+ dtype == mshadow::kBfloat16) &&
+ SupportStorageMKLDNN(input.storage_type());
}
void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 7b36d25..40d677a 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -704,8 +704,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
- if (!param.use_global_stats && !param.cudnn_off
- && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+ if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
})
@@ -733,8 +732,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
- if (!param.use_global_stats && !param.cudnn_off
- && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+ if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
})
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index fc91212..797234c 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -260,15 +260,27 @@ class CuDNNBatchNormOp {
private:
void Init(const TBlob &in_data) {
- if (in_data.ndim() == 4) {
- for (int i = 0; i < 4; ++i)
- shape_[i] = in_data.shape_[i];
+ CHECK_GE(param_.axis, 0);
+ CHECK_LT(param_.axis, in_data.ndim());
+ if (param_.axis == 1) {
+ if (in_data.ndim() == 4) {
+ for (int i = 0; i < 4; ++i)
+ shape_[i] = in_data.shape_[i];
+ } else {
+ // when in_data.ndim() != 4
+ shape_[0] = in_data.shape_[0];
+ shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+ shape_[2] = 1;
+ shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+ in_data.ndim()));
+ }
} else {
- // when in_data.ndim() != 4
- shape_[0] = in_data.shape_[0];
- shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+ // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+ shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+ shape_[1] = in_data.shape_[param_.axis];
shape_[2] = 1;
- shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+ shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+ in_data.ndim()));
}
CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 2021ba0..18055ca 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -151,7 +151,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs, const
std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() +
batchnorm::kInMovingMean);
+ std::vector<NDArray> in_data(inputs.begin(), inputs.begin() +
batchnorm::kInMovingMean);
+
+ mxnet::TShape shape = inputs[batchnorm::kData].shape();
+ const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+ CHECK_LT(real_axis, shape.ndim());
+ NDArray out = outputs[batchnorm::kOut];
+ if (param.axis != 1 || shape.ndim() != 4) {
+ // reshape to (N, C, 1, D)
+ mxnet::TShape new_shape{
+ static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+ shape[real_axis],
+ 1,
+ static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+ static_cast<int>(shape.ndim())))
+ };
+ in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+ out = out.Reshape(new_shape);
+ }
+
const std::vector<NDArray> aux_states(inputs.begin() +
batchnorm::kInMovingMean, inputs.end());
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -160,7 +178,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
ctx.is_train &&
!param.use_global_stats);
const NDArray &data = in_data[batchnorm::kData];
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
- const NDArray &out = outputs[batchnorm::kOut];
// for output memory
auto out_mem = const_cast<NDArray
&>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -304,9 +321,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
param,
ctx.is_train &&
!param.use_global_stats);
- const NDArray &data = in_data[batchnorm::kData];
- const NDArray &diff = out_grad[batchnorm::kOut];
- const NDArray &gradIn = in_grad[batchnorm::kData];
+ NDArray data = in_data[batchnorm::kData];
+ NDArray diff = out_grad[batchnorm::kOut];
+ NDArray gradIn = in_grad[batchnorm::kData];
const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -317,6 +334,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
CHECK(moving_mean.IsDefaultData());
CHECK(moving_var.IsDefaultData());
+ mxnet::TShape shape = data.shape();
+ const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+ CHECK_LT(real_axis, shape.ndim());
+ if (param.axis != 1 || shape.ndim() != 4) {
+ // reshape to (N, C, 1, D)
+ mxnet::TShape new_shape{
+ static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+ shape[real_axis],
+ 1,
+ static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+ static_cast<int>(shape.ndim())))
+ };
+ data = data.Reshape(new_shape);
+ diff = diff.Reshape(new_shape);
+ gradIn = gradIn.Reshape(new_shape);
+ }
+
auto data_mem = data.GetMKLDNNData();
auto diff_mem = diff.GetMKLDNNData();
// MKLDNN batchnorm should run on special layouts. If one of them isn't, we
diff --git a/tests/python/unittest/test_numpy_op.py
b/tests/python/unittest/test_numpy_op.py
index 4811b34..b218c48 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1541,7 +1541,7 @@ def test_npx_batch_norm():
assert_almost_equal(
bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol,
rtol=rtol)
- shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+ shapes = [(4, 2), (4, 3, 4), (4, 6, 4, 5), (4, 5, 6, 4, 5)]
bools = [False, True]
for shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
shapes, bools, bools, bools):
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index b9e2422..de47875 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1964,7 +1964,7 @@ def test_batchnorm():
bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol,
rtol=rtol)
op_names = ['BatchNorm', 'SyncBatchNorm']
- shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+ shapes = [(4, 2), (4, 3, 4), (4, 6, 4, 5), (4, 5, 6, 4, 5)]
bools = [False, True]
for op_name, shape, fix_gamma, cudnn_off, output_mean_var in
itertools.product(
op_names, shapes, bools, bools, bools):
diff --git a/tests/python/unittest/test_symbol.py
b/tests/python/unittest/test_symbol.py
index f9d6cea..00b27df 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -272,36 +272,6 @@ def check_symbol_consistency(sym1, sym2, ctx,
skip_grad=False, equal_nan=False):
grad_req='null' if skip_grad else 'write',
equal_nan=equal_nan)
-def test_load_000800():
- with mx.AttrScope(ctx_group='stage1'):
- data = mx.symbol.Variable('data', lr_mult=0.2)
- weight = mx.sym.Variable(name='fc1_weight', lr_mult=1.2)
- fc1 = mx.symbol.FullyConnected(data = data, weight=weight,
name='fc1', num_hidden=128, wd_mult=0.3)
- act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-
- set_stage1 = set(act1.list_arguments())
- with mx.AttrScope(ctx_group='stage2'):
- fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden
= 64, lr_mult=0.01)
- act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
- fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
- fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0')
- sym1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
-
- curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
- sym2 = mx.sym.load(os.path.join(curr_path, 'save_000800.json'))
-
- attr1 = sym1.attr_dict()
- attr2 = sym2.attr_dict()
- for k, v1 in attr1.items():
- assert k in attr2, k
- v2 = attr2[k]
- for kk, vv1 in v1.items():
- if kk.startswith('__') and kk.endswith('__'):
- assert kk in v2 and v2[kk] == vv1, k + str(v1) + str(v2)
-
- check_symbol_consistency(sym1, sym2,
- {'ctx': mx.cpu(0), 'group2ctx': {'stage1' : mx.cpu(1), 'stage2' :
mx.cpu(2)}, 'data': (1,200)})
-
def test_blockgrad():
a = mx.sym.Variable('a')