eric-haibin-lin closed pull request #11498: Fix InferStorage for sparse fallback in FullyConnected
URL: https://github.com/apache/incubator-mxnet/pull/11498

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index 7eba2e20e57..2338f8974aa 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -35,6 +35,7 @@
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
 #include "../linalg.h"
+#include "../../common/utils.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 48d479ccf60..46772a4db18 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -210,17 +210,21 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), out_expected);
 
-  DispatchMode wanted_mode;
-#if 0
+  bool dispatched = false;
   // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
   // It seems there is a bug.
-  if (dev_mask == mshadow::cpu::kDevMask)
-    *dispatch_mode = DispatchMode::kFComputeEx;
-  else
-#endif
-    wanted_mode = DispatchMode::kFCompute;
-  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                             dispatch_mode, wanted_mode);
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
+    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  if (!dispatched) {
+    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  return dispatched;
 }
 
 DMLC_REGISTER_PARAMETER(FullyConnectedParam);
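
With this change, BackwardFCStorageType tries three cases in order: all-dense
inputs dispatch to the dense kFCompute kernel, any row_sparse input triggers
dispatch_fallback (which casts sparse inputs to dense before invoking the
dense kernel), and anything else still defaults to dense dispatch. A minimal
sketch of how this surfaces at the Python level, assuming a build that
includes this fix; the mx.nd.FullyConnected call on row_sparse inputs and the
expected gradient storage types below are illustrative assumptions, not part
of the PR:

    import mxnet as mx

    # Before this fix, backward storage inference unconditionally assigned
    # plain kFCompute, which is invalid for non-default (sparse) inputs.
    x = mx.nd.ones((4, 3)).tostype('row_sparse')
    w = mx.nd.ones((2, 3)).tostype('row_sparse')
    x.attach_grad(stype='default')  # the fallback produces dense gradients
    w.attach_grad(stype='default')

    with mx.autograd.record():
        y = mx.nd.FullyConnected(data=x, weight=w, num_hidden=2, no_bias=True)
    y.backward()
    print(x.grad.stype, w.grad.stype)  # expected: default default
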
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 8c296deef20..a6d7743e926 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -22,7 +22,7 @@
 import os
 import numpy as np
 import mxnet as mx
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import rand_ndarray, assert_almost_equal
 from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet.test_utils import *
@@ -240,5 +240,25 @@ def check_batchnorm_training(stype):
     for stype in stypes:
         check_batchnorm_training(stype)
 
+
+@with_seed()
+def test_fullyconnected():
+    def check_fullyconnected_training(stype):
+        data_shape = rand_shape_nd(2)
+        weight_shape = rand_shape_nd(2)
+        weight_shape = (weight_shape[0], data_shape[1])
+        for density in [1.0, 0.5, 0.0]:
+            x = rand_ndarray(shape=data_shape, stype=stype, density=density)
+            w = rand_ndarray(shape=weight_shape, stype=stype, density=density)
+            x_sym = mx.sym.Variable("data")
+            w_sym = mx.sym.Variable("weight")
+            sym = mx.sym.FullyConnected(data=x_sym, weight=w_sym, num_hidden=weight_shape[0], no_bias=True)
+            in_location = [x, w]
+            check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3)
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_fullyconnected_training(stype)
+
+
 if __name__ == '__main__':
     test_mkldnn_install()
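
The new test drives this fallback end to end through check_numeric_gradient,
sweeping densities 1.0, 0.5 and 0.0 (density=0.0 yields all-zero sparse
arrays, the edge most likely to trip storage inference) for both row_sparse
and default storage. A standalone sketch of the same check with fixed shapes;
the committed test randomizes shapes via rand_shape_nd, so the shapes here
are an illustrative assumption:

    import mxnet as mx
    from mxnet.test_utils import rand_ndarray, check_numeric_gradient

    for density in [1.0, 0.5, 0.0]:
        x = rand_ndarray(shape=(3, 4), stype='row_sparse', density=density)
        w = rand_ndarray(shape=(5, 4), stype='row_sparse', density=density)
        sym = mx.sym.FullyConnected(data=mx.sym.Variable('data'),
                                    weight=mx.sym.Variable('weight'),
                                    num_hidden=5, no_bias=True)
        check_numeric_gradient(sym, [x, w], numeric_eps=1e-3,
                               rtol=1e-3, atol=5e-3)
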


 
