haojin2 closed pull request #11631: Fix batchnorm problem with sparse matrices when fix_gamma=True
URL: https://github.com/apache/incubator-mxnet/pull/11631
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/operator/batch_norm_v1.cc b/src/operator/batch_norm_v1.cc
index 5da4af25368..cefadc481a2 100644
--- a/src/operator/batch_norm_v1.cc
+++ b/src/operator/batch_norm_v1.cc
@@ -89,6 +89,9 @@ the output. It is often used during inference.
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
 then set ``gamma`` to 1 and its gradient to 0.
 
+There's no sparse support for this operator, and it will exhibit problematic behavior if used with
+sparse tensors.
+
 )code" ADD_FILELINE)
 .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
 .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 1f9e8289f4a..6548f4692ac 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -444,6 +444,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
   }
   FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
+#endif
 
 static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                                         const int dev_mask,
@@ -452,8 +453,32 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                                         std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 5);
   CHECK_EQ(out_attrs->size(), 3);
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
-                           in_attrs, out_attrs);
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
+
+  if ((common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
+       common::ContainsStorageType(*in_attrs, kCSRStorage)) &&
+      param.fix_gamma) {
+    LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
+  for (int& v : *in_attrs)
+    if (v == - 1) v = kDefaultStorage;
+  bool dispatched = false;
+  if (!dispatched && (common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
+                      common::ContainsStorageType(*in_attrs, kCSRStorage))) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+#if MXNET_USE_MKLDNN == 1
+  if (!dispatched) {
+    dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
+                                   in_attrs, out_attrs);
+  }
+#else
+  if (!dispatched) {
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+#endif
+  return dispatched;
 }
 
 static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
@@ -461,10 +486,33 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                                                  DispatchMode *dispatch_mode,
                                                  std::vector<int> *in_attrs,
                                                  std::vector<int> *out_attrs) {
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
-                           in_attrs, out_attrs);
-}
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
+
+  if ((common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
+       common::ContainsStorageType(*in_attrs, kCSRStorage)) &&
+      param.fix_gamma) {
+    LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
+  for (int& v : *in_attrs)
+    if (v == - 1) v = kDefaultStorage;
+  bool dispatched = false;
+  if (!dispatched && (common::ContainsStorageType(*in_attrs, kRowSparseStorage) ||
+                      common::ContainsStorageType(*in_attrs, kCSRStorage))) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+#if MXNET_USE_MKLDNN == 1
+  if (!dispatched) {
+    dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
+                                   in_attrs, out_attrs);
+  }
+#else
+  if (!dispatched) {
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
 #endif
+  return dispatched;
+}
 
 std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
                                            const std::vector<nnvm::NodeEntry>& ograds) {
@@ -552,6 +600,11 @@ axis to be the last item in the input shape.
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
 then set ``gamma`` to 1 and its gradient to 0.
 
+Note::
+
+When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+sparse tensors will fall back to the dense implementation.
+
 )code" ADD_FILELINE)
 .set_num_inputs(5)
 .set_num_outputs(3)
@@ -574,9 +627,7 @@ then set ``gamma`` to 1 and its gradient to 0.
 })
 .set_attr<nnvm::FInferShape>("FInferShape", BatchNormShape)
 .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#endif
 .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 8c296deef20..9aa93b25b27 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -233,7 +233,7 @@ def check_batchnorm_training(stype):
                            mx.nd.array(beta).tostype(stype)]
             mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
     stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 652960191c7..cb4bbaf0b3e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1527,9 +1527,7 @@ def check_batchnorm_training(stype):
                 test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-    stypes = ['default']
-    for stype in stypes:
-        check_batchnorm_training(stype)
+    check_batchnorm_training('default')
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 95689b785db..e51a49424c8 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -16,7 +16,8 @@
 # under the License.
 
 from mxnet.test_utils import *
-from common import setup_module, with_seed, teardown
+from mxnet.base import MXNetError
+from common import setup_module, with_seed, teardown, assertRaises
 import random
 import warnings
 
@@ -2098,6 +2099,78 @@ def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, for
                           lambda l, r: l + r,
                           rhs_is_scalar=True, verbose=False, density=0.5)
 
+
+@with_seed()
+def test_batchnorm_fallback():
+    # same test as test_operator.test_batchnorm_training, but tests fallback logic of batchnorm
+    stype = 'row_sparse'
+    for shape in [(2, 3), (2, 3, 2, 2)]:
+        data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+        s = shape[1],
+        gamma = np.ones(s)
+        beta = np.ones(s)
+        gamma[1] = 3
+        beta[0] = 3
+
+        rolling_mean = np.random.uniform(size=s)
+        rolling_std = np.random.uniform(size=s)
+
+        data = mx.symbol.Variable('data', stype=stype)
+        in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                        mx.nd.array(beta).tostype(stype)]
+        mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False)
+        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        # Test varying channel axis
+        dim = len(shape)
+        for chaxis in range(-dim, dim):
+            chaxis_true = chaxis
+            if chaxis < 0:
+                chaxis_true = dim + chaxis
+
+            shapex = shape
+
+            channel_count = shapex[chaxis_true]
+            data_tmp = np.random.normal(-0.1, 0.1, size=shapex)
+
+            gamma = np.ones(channel_count)
+            beta = np.ones(channel_count)
+            if channel_count > 1:
+                gamma[1] = 3
+            beta[0] = 3
+
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                            mx.nd.array(beta).tostype(stype)]
+
+            xrolling_mean = np.random.uniform(size=channel_count)
+            xrolling_std = np.random.uniform(size=channel_count)
+            xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
+                            mx.nd.array(xrolling_std).tostype(stype)]
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+
 @with_seed()
 def test_mkldnn_sparse():
     # This test is trying to create a race condition describedd in


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
