eric-haibin-lin closed pull request #11498: Fix InferStorage for sparse fallback in FullyConnected URL: https://github.com/apache/incubator-mxnet/pull/11498
This PR was merged from a forked repository. Because GitHub hides the original diff of a pull request merged from a fork, the diff is reproduced below for the sake of provenance: diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h index 7eba2e20e57..2338f8974aa 100644 --- a/src/operator/nn/fully_connected-inl.h +++ b/src/operator/nn/fully_connected-inl.h @@ -35,6 +35,7 @@ #include "../operator_common.h" #include "../elemwise_op_common.h" #include "../linalg.h" +#include "../../common/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 48d479ccf60..46772a4db18 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -210,17 +210,21 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), out_expected); - DispatchMode wanted_mode; -#if 0 + bool dispatched = false; // TODO(zhengda) let's disable MKLDNN for FullyConnected for now. // It seems there is a bug. 
- if (dev_mask == mshadow::cpu::kDevMask) - *dispatch_mode = DispatchMode::kFComputeEx; - else -#endif - wanted_mode = DispatchMode::kFCompute; - return storage_type_assign(out_attrs, mxnet::kDefaultStorage, - dispatch_mode, wanted_mode); + if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) { + storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) { + dispatched = dispatch_fallback(out_attrs, dispatch_mode); + } + if (!dispatched) { + dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + return dispatched; } DMLC_REGISTER_PARAMETER(FullyConnectedParam); diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 8c296deef20..a6d7743e926 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -22,7 +22,7 @@ import os import numpy as np import mxnet as mx -from mxnet.test_utils import assert_almost_equal +from mxnet.test_utils import rand_ndarray, assert_almost_equal from mxnet import gluon from mxnet.gluon import nn from mxnet.test_utils import * @@ -240,5 +240,25 @@ def check_batchnorm_training(stype): for stype in stypes: check_batchnorm_training(stype) + +@with_seed() +def test_fullyconnected(): + def check_fullyconnected_training(stype): + data_shape = rand_shape_nd(2) + weight_shape = rand_shape_nd(2) + weight_shape = (weight_shape[0], data_shape[1]) + for density in [1.0, 0.5, 0.0]: + x = rand_ndarray(shape=data_shape, stype=stype, density=density) + w = rand_ndarray(shape=weight_shape, stype=stype, density=density) + x_sym = mx.sym.Variable("data") + w_sym = mx.sym.Variable("weight") + sym = mx.sym.FullyConnected(data=x_sym, weight=w_sym, num_hidden=weight_shape[0], no_bias=True) + in_location = [x, w] + check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, 
atol=5e-3) + stypes = ['row_sparse', 'default'] + for stype in stypes: + check_fullyconnected_training(stype) + + if __name__ == '__main__': test_mkldnn_install() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services