access2rohit opened a new pull request #17805: fixing batch_norm and layer_norm for large tensors
URL: https://github.com/apache/incubator-mxnet/pull/17805

## Description ##

Enables large tensor support for the following ops:
1. batch_norm
2. layer_norm

## Checklist ##

### Essentials ###

Please feel free to remove inapplicable items for your PR.

- [x] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage
- [x] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

### Proof Of Correctness ###

In both shape-inference functions, `channelCount` was declared as `int`, so a dimension of 4300000000 wrapped modulo 2^32 down to 5032704 (4300000000 - 2^32 = 5032704). Declaring it as `index_t` preserves the full 64-bit value, as the gdb sessions below show.

## layer_norm() ##

Before changes:

```
Thread 1 "python3" hit Breakpoint 1, mxnet::op::LayerNormShape (attrs=..., in_shape=0x555556579dc8, out_shape=0x555556579de0)
    at src/operator/nn/layer_norm.cc:50
50          const int channelCount = dshape[axis];
(gdb) n
52          if (!mxnet::ndim_is_known(dshape)) {
(gdb) p channelCount
$3 = 5032704    <--------
(gdb) p dshape[0]
$4 = (long &) @0x555556c21f58: 4300000000    <--------
(gdb) info local
param = @0x7fffffff9418: {<dmlc::Parameter<mxnet::op::LayerNormParam>> = {<No data fields>}, axis = 0, eps = 9.99999975e-06, output_mean_var = false}
dshape = @0x555556c21f50: {<mxnet::Tuple<long>> = {static kStackCache = <optimized out>, ndim_ = 1, num_heap_allocated_ = 0, data_stack_ = {4300000000, 0, 0, 0}, data_heap_ = 0x0}, <No data fields>}
axis = 0
channelCount = 5032704
moments_shape = {<mxnet::Tuple<long>> = {static kStackCache = <optimized out>, ndim_ = -29512, num_heap_allocated_ = 32767, data_stack_ = {140737488326480, 140737488325376, 93825019642720, 140737488325376}, data_heap_ = 0x7fff936c4de7 <std::_Rb_tree<dmlc::parameter::FieldAccessEntry*, dmlc::parameter::FieldAccessEntry*, std::_Identity<dmlc::parameter::FieldAccessEntry*>, std::less<dmlc::parameter::FieldAccessEntry*>, std::allocator<dmlc::parameter::FieldAccessEntry*> >::_Alloc_node::operator()<dmlc::parameter::FieldAccessEntry* const&>(dmlc::parameter::FieldAccessEntry* const&) const+49>}, <No data fields>}
```

After changes:

```
Thread 1 "python3" hit Breakpoint 2, mxnet::op::LayerNormShape (attrs=..., in_shape=0x555556578ff8, out_shape=0x555556579010)
    at src/operator/nn/layer_norm.cc:50
50          const index_t channelCount = dshape[axis];
(gdb) n
52          if (!mxnet::ndim_is_known(dshape)) {
(gdb) info local
param = @0x7fffffff9438: {<dmlc::Parameter<mxnet::op::LayerNormParam>> = {<No data fields>}, axis = 0, eps = 9.99999975e-06, output_mean_var = false}
dshape = @0x5555565bc420: {<mxnet::Tuple<long>> = {static kStackCache = <optimized out>, ndim_ = 1, num_heap_allocated_ = 0, data_stack_ = {4300000000, 6878235116697514089, 32088647312828786, 0}, data_heap_ = 0x0}, <No data fields>}
axis = 0
channelCount = 4300000000    <--------
moments_shape = {<mxnet::Tuple<long>> = {static kStackCache = <optimized out>, ndim_ = -29480, num_heap_allocated_ = 32767, data_stack_ = {140737488326512, 140737488325408, 93825021150800, 140737488325408}, data_heap_ = 0x7fff936c4de7 <std::_Rb_tree<dmlc::parameter::FieldAccessEntry*, dmlc::parameter::FieldAccessEntry*, std::_Identity<dmlc::parameter::FieldAccessEntry*>, std::less<dmlc::parameter::FieldAccessEntry*>, std::allocator<dmlc::parameter::FieldAccessEntry*> >::_Alloc_node::operator()<dmlc::parameter::FieldAccessEntry* const&>(dmlc::parameter::FieldAccessEntry* const&) const+49>}, <No data fields>}
(gdb) p dshape[axis]
$1 = (long &) @0x5555565bc428: 4300000000    <--------
```

## batch_norm() ##

Before changes:

```
333         const int channelCount = dshape[channelAxis];
(gdb) info local
param = @0x555555cdb770: {<dmlc::Parameter<mxnet::op::BatchNormParam>> = {<No data fields>}, eps = 0.0010000000474974513, momentum = 0.899999976, fix_gamma = true, use_global_stats = false, output_mean_var = false, axis = 0, cudnn_off = false, min_calib_range = {is_none = true, val = {__data = "\000\000\000", __align = {<No data fields>}}}, max_calib_range = {is_none = true, val = {__data = "UU\000", __align = {<No data fields>}}}}
dshape = @0x5555572290a0: {<mxnet::Tuple<long>> = {static kStackCache = <optimized out>, ndim_ = 1, num_heap_allocated_ = 0, data_stack_ = {4300000000, 1, 4300000000, 0}, data_heap_ = 0x0}, <No data fields>}
channelAxis = 0
channelCount = 21845    <--------
(gdb) p dshape[channelAxis]
$1 = (long &) @0x5555572290a8: 4300000000    <--------
(gdb) n
335         if (!mxnet::ndim_is_known(dshape)) {
(gdb) p channelCount
$2 = 5032704
```

After changes:

```
Thread 1 "python3" hit Breakpoint 1, mxnet::op::BatchNormShape (attrs=..., in_shape=0x555556579d98, out_shape=0x555556579db0)
    at src/operator/nn/batch_norm.cc:333
333         const index_t channelCount = dshape[channelAxis];
(gdb) n
335         if (!mxnet::ndim_is_known(dshape)) {
(gdb) info local
param = @0x555555cdb770: {<dmlc::Parameter<mxnet::op::BatchNormParam>> = {<No data fields>}, eps = 0.0010000000474974513, momentum = 0.899999976, fix_gamma = true, use_global_stats = false, output_mean_var = false, axis = 0, cudnn_off = false, min_calib_range = {is_none = true, val = {__data = "\000\000\000", __align = {<No data fields>}}}, max_calib_range = {is_none = true, val = {__data = "UU\000", __align = {<No data fields>}}}}
dshape = @0x5555572290a0: {<mxnet::Tuple<long>> = {static kStackCache = <optimized out>, ndim_ = 1, num_heap_allocated_ = 0, data_stack_ = {4300000000, 1, 4300000000, 0}, data_heap_ = 0x0}, <No data fields>}
channelAxis = 0
channelCount = 4300000000    <--------
(gdb) p dshape[channelAxis]
$1 = (long &) @0x5555572290a8: 4300000000    <--------
```

## Testing ##

```
$ MXNET_TEST_COUNT=1 nosetests --logging-level=DEBUG --verbose -s tests/nightly/test_large_vector.py:test_nn
/home/ubuntu/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
test_large_vector.test_nn ...
[18:14:51] src/executor/graph_executor.cc:1981: Subgraph backend MKLDNN is activated.
[18:21:14] src/executor/graph_executor.cc:1981: Subgraph backend MKLDNN is activated.
ok

----------------------------------------------------------------------
Ran 1 test in 1017.457s

OK
```