anirudh2290 closed pull request #12594: [MXNET-867] Pooling1D with "same" padding URL: https://github.com/apache/incubator-mxnet/pull/12594
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 3ae61298de8..55416355d8a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -180,3 +180,4 @@ List of Contributors * [Per Goncalves da Silva](https://github.com/perdasilva) * [Zhijingcheng Yu](https://github.com/jasonyu1996) * [Cheng-Che Lee](https://github.com/stu1130) +* [Chaitanya Bapat](https://github.com/ChaiBapchya) \ No newline at end of file diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h index 8f7a5edc832..33005c8e5f0 100644 --- a/src/operator/nn/pool.h +++ b/src/operator/nn/pool.h @@ -73,7 +73,7 @@ namespace pool_enum { enum PoolingOpInputs {kData}; enum PoolingOpOutputs {kOut, kMask}; enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling}; -enum PoolingOpPadConventionType {kValid, kFull}; +enum PoolingOpPadConventionType {kValid, kFull, kSame}; } // namespace pool_enum /*! 
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index ad74a8feae3..71d85da9ba5 100644 --- a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -74,6 +74,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> { DMLC_DECLARE_FIELD(pooling_convention).set_default(pool_enum::kValid) .add_enum("full", pool_enum::kFull) .add_enum("valid", pool_enum::kValid) + .add_enum("same", pool_enum::kSame) .describe("Pooling convention to be applied."); DMLC_DECLARE_FIELD(stride).set_default(TShape()) diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 558722edb20..611568807a9 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -96,6 +96,13 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, CHECK(param.p_value.has_value()); } const TShape &dshape = (*in_shape)[0]; + if (param.pooling_convention == pool_enum::kSame) { + CHECK_EQ(dshape.ndim(), 3U) + << "Pooling: Input data should be 3D in (batch, channel, x)" + << ". 
Currently 'same' supports Max Pooling 1-D"; + CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0) + << "Same pooling convention disables the use of pad parameter."; + } CHECK_GE(dshape.ndim(), 3U) << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " @@ -126,11 +133,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, oshape[2] = 1 + (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / param.stride[0]; - } else { + } else if (param.pooling_convention == pool_enum::kFull) { oshape[2] = 1 + static_cast<int>(std::ceil( static_cast<float>(dshape[2] + 2 * param.pad[0] - param.kernel[0]) / param.stride[0])); + } else { + oshape[2] = static_cast<int>(std::ceil( + static_cast<float>(dshape[2] + 2 * param.pad[0]) / + param.stride[0])); } out_shape->clear(); out_shape->push_back(oshape); // save output shape diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 43c357808f1..a7f484e81b3 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6975,6 +6975,40 @@ def test_valid_kernel_size(): mx.nd.array(np.random.rand(1, 1, 28, 28)), kernel_size=valid_kernel_size) +@with_seed() +def test_valid_max_pooling_pad_type_same(): + import math + input_data = mx.nd.array(np.random.rand(1,1,10)) + stride = 2 + kernel = 2 + output_data=mx.nd.Pooling( + input_data, + kernel=kernel, + stride=stride, + pad=(0,0,0), + pool_type='max', + name='pooling', + pooling_convention="same") + assert(math.ceil(input_data.shape[2]/stride) == output_data.shape[2]) + +@with_seed() +def test_invalid_max_pooling_pad_type_same(): + import math + input_data = mx.nd.array(np.random.rand(1,1,10)) + stride = 2 + kernel = 2 + pad = 2 + assert_exception( + mx.nd.Pooling, + MXNetError, + input_data, + stride=stride, + kernel=kernel, + pad=pad, + pool_type='max', + name='pooling', + pooling_convention="same") + if __name__ == '__main__': import nose 
nose.runmodule() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
