marcoabreu closed pull request #11656: Fix batchnorm problem with sparse matrices when fix_gamma=True
URL: https://github.com/apache/incubator-mxnet/pull/11656
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

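For orientation, here is a minimal sketch of the user-visible behavior the change below enforces. It assumes a build with this patch applied, and that the imperative NDArray call is routed through the same storage-type inference as the symbolic executor exercised by the new test; shapes and values are arbitrary.

import mxnet as mx

# Row-sparse inputs to BatchNorm (toy shapes, arbitrary values).
x     = mx.nd.ones((2, 3)).tostype('row_sparse')
gamma = mx.nd.ones((3,)).tostype('row_sparse')
beta  = mx.nd.zeros((3,)).tostype('row_sparse')
mean  = mx.nd.zeros((3,)).tostype('row_sparse')
var   = mx.nd.ones((3,)).tostype('row_sparse')

# fix_gamma=False: sparse inputs fall back to the dense implementation.
y = mx.nd.BatchNorm(x, gamma, beta, mean, var, fix_gamma=False)

# fix_gamma=True with sparse inputs now fails fast instead of silently
# misbehaving (tracked at issue #11647).
try:
    mx.nd.BatchNorm(x, gamma, beta, mean, var, fix_gamma=True)
except mx.base.MXNetError as err:
    print('raised as expected:', str(err).splitlines()[0])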

diff --git a/src/operator/batch_norm_v1.cc b/src/operator/batch_norm_v1.cc
index 5da4af25368..2d19107eda1 100644
--- a/src/operator/batch_norm_v1.cc
+++ b/src/operator/batch_norm_v1.cc
@@ -89,6 +89,9 @@ the output. It is often used during inference.
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
 then set ``gamma`` to 1 and its gradient to 0.
 
+There's no sparse support for this operator, and it will exhibit problematic behavior if used with
+sparse tensors.
+
 )code" ADD_FILELINE)
 .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
 .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 1f9e8289f4a..30fb665dd05 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -321,6 +321,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
+  CHECK_EQ(out_shape->size(), 3U);
   const TShape &dshape = in_shape->at(batchnorm::kData);
 
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
@@ -444,27 +445,37 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
   }
   FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
+#endif
 
 static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                                         const int dev_mask,
                                         DispatchMode *dispatch_mode,
                                         std::vector<int> *in_attrs,
                                         std::vector<int> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 5);
-  CHECK_EQ(out_attrs->size(), 3);
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
-                           in_attrs, out_attrs);
-}
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
 
-static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
-                                                 const int dev_mask,
-                                                 DispatchMode *dispatch_mode,
-                                                 std::vector<int> *in_attrs,
-                                                 std::vector<int> *out_attrs) {
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
-                           in_attrs, out_attrs);
-}
+  bool dispatched = false;
+#if MXNET_USE_MKLDNN == 1
+  if (!dispatched) {
+    dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
+                                   in_attrs, out_attrs);
+  }
+#else
+  for (int& v : *in_attrs)
+    if (v == - 1) v = kDefaultStorage;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
 #endif
+  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
+    LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
+  return dispatched;
+}
 
 std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) {
@@ -552,6 +563,11 @@ axis to be the last item in the input shape.
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
 then set ``gamma`` to 1 and its gradient to 0.
 
+Note::
+
+When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+sparse tensors will fall back to the dense implementation.
+
 )code" ADD_FILELINE)
 .set_num_inputs(5)
 .set_num_outputs(3)
@@ -574,9 +590,7 @@ then set ``gamma`` to 1 and its gradient to 0.
 })
 .set_attr<nnvm::FInferShape>("FInferShape", BatchNormShape)
 .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#endif
 .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
@@ -607,8 +621,8 @@ then set ``gamma`` to 1 and its gradient to 0.
 NNVM_REGISTER_OP(_backward_BatchNorm)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
 #if MXNET_USE_MKLDNN == 1
-.set_attr<FInferStorageType>("FInferStorageType", backward_BatchNormStorageType)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
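Taken together, the rewritten BatchNormStorageType above boils down to the following dispatch rule. This is an illustrative Python rendering of the C++ logic, not MXNet API; the strings stand in for the storage-type and dispatch-mode enums (kDefaultStorage, DispatchMode::kFCompute, and so on).

# Illustrative mirror of the new BatchNormStorageType (not real MXNet code).
def infer_batchnorm_storage(in_storages, fix_gamma, use_mkldnn):
    dispatched = False
    mode = None
    if use_mkldnn:
        # MKLDNN builds delegate to MKLDNNStorageType() as before.
        dispatched, mode = True, 'mkldnn'
    else:
        # Unknown (-1) input storage types default to dense.
        in_storages = ['default' if s == -1 else s for s in in_storages]
        if not dispatched and all(s == 'default' for s in in_storages):
            dispatched, mode = True, 'fcompute'   # run the dense kernel directly
        if not dispatched:
            dispatched, mode = True, 'fallback'   # densify inputs, compute, convert back
    # Applies to both build flavours: sparse inputs combined with fix_gamma=True
    # are rejected outright (tracked at issue #11647).
    if fix_gamma and any(s != 'default' for s in in_storages):
        raise RuntimeError('fix_gamma=True is not supported for sparse ndarrays')
    return mode

print(infer_batchnorm_storage(['row_sparse', 'default', 'default', 'default', 'default'],
                              fix_gamma=False, use_mkldnn=False))   # -> fallback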
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index d9099cb57d4..d1d84e97529 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -213,8 +213,8 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
   // It seems there is a bug.
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
-    storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                        dispatch_mode, DispatchMode::kFCompute);
+    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
   }
   if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
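The FullyConnected hunk above is a one-line bug fix: the return value of storage_type_assign was being discarded, so dispatched stayed false even when the dense path had already been selected, and the function fell through to the later branches. A tiny hypothetical sketch of the before/after pattern (plain Python, stand-in names):

def assign_dense_storage():
    # Stand-in for storage_type_assign(..., kDefaultStorage, ..., kFCompute).
    return True

dispatched = False

# Before: the result was dropped, so 'dispatched' never flipped to True.
assign_dense_storage()

# After: the result is recorded, so later fallback branches are skipped
# once the dense dispatch has succeeded.
if not dispatched:
    dispatched = assign_dense_storage()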
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index a6d7743e926..ff9ba538b95 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -233,7 +233,7 @@ def check_batchnorm_training(stype):
                            mx.nd.array(beta).tostype(stype)]
             mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
     stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index f9dde2e6d24..6c6ff310519 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1552,9 +1552,7 @@ def check_batchnorm_training(stype):
                 test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-    stypes = ['default']
-    for stype in stypes:
-        check_batchnorm_training(stype)
+    check_batchnorm_training('default')
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 95689b785db..e51a49424c8 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -16,7 +16,8 @@
 # under the License.
 
 from mxnet.test_utils import *
-from common import setup_module, with_seed, teardown
+from mxnet.base import MXNetError
+from common import setup_module, with_seed, teardown, assertRaises
 import random
 import warnings
 
@@ -2098,6 +2099,78 @@ def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, for
                           lambda l, r: l + r,
                           rhs_is_scalar=True, verbose=False, density=0.5)
 
+
+@with_seed()
+def test_batchnorm_fallback():
+    # same test as test_operator.test_batchnorm_training, but tests fallback logic of batchnorm
+    stype = 'row_sparse'
+    for shape in [(2, 3), (2, 3, 2, 2)]:
+        data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+        s = shape[1],
+        gamma = np.ones(s)
+        beta = np.ones(s)
+        gamma[1] = 3
+        beta[0] = 3
+
+        rolling_mean = np.random.uniform(size=s)
+        rolling_std = np.random.uniform(size=s)
+
+        data = mx.symbol.Variable('data', stype=stype)
+        in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                        mx.nd.array(beta).tostype(stype)]
+        mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False)
+        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        # Test varying channel axis
+        dim = len(shape)
+        for chaxis in range(-dim, dim):
+            chaxis_true = chaxis
+            if chaxis < 0:
+                chaxis_true = dim + chaxis
+
+            shapex = shape
+
+            channel_count = shapex[chaxis_true]
+            data_tmp = np.random.normal(-0.1, 0.1, size=shapex)
+
+            gamma = np.ones(channel_count)
+            beta = np.ones(channel_count)
+            if channel_count > 1:
+                gamma[1] = 3
+            beta[0] = 3
+
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                            mx.nd.array(beta).tostype(stype)]
+
+            xrolling_mean = np.random.uniform(size=channel_count)
+            xrolling_std = np.random.uniform(size=channel_count)
+            xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
+                            mx.nd.array(xrolling_std).tostype(stype)]
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+
 @with_seed()
 def test_mkldnn_sparse():
     # This test is trying to create a race condition described in


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
