This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 53d786d [MXNET-952] Check for correlation kernel size along with unittest (#12558)
53d786d is described below
commit 53d786dc6e3485714c9d1c23a4b2e3e1857922e7
Author: Chaitanya Prakash Bapat <[email protected]>
AuthorDate: Tue Sep 18 10:53:46 2018 -0700
[MXNET-952] Check for correlation kernel size along with unittest (#12558)
* added a line to check for kernel size and unittest for the same
* Update correlation-inl.h
* Update test_operator.py
* fix space and long line issue
---
src/operator/correlation-inl.h | 1 +
tests/python/unittest/test_operator.py | 18 ++++++++++++++++++
2 files changed, 19 insertions(+)
diff --git a/src/operator/correlation-inl.h b/src/operator/correlation-inl.h
index 9dca44e..e1cc972 100644
--- a/src/operator/correlation-inl.h
+++ b/src/operator/correlation-inl.h
@@ -78,6 +78,7 @@ class CorrelationOp : public Operator {
using namespace mshadow;
CHECK_EQ(in_data.size(), 2U);
CHECK_EQ(out_data.size(), 3U);
+ CHECK_NE(param_.kernel_size % 2, 0) << "kernel size should be odd number";
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4, DType> data1 = in_data[Correlation::kData1].get<xpu, 4,
DType>(s);
Tensor<xpu, 4, DType> data2 = in_data[Correlation::kData2].get<xpu, 4,
DType>(s);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2bf7e84..5937ecd 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6913,6 +6913,24 @@ def test_spacetodepth():
test_invalid_block_size()
test_invalid_depth_dim()
+@with_seed()
+def test_invalid_kernel_size():
+ invalid_kernel_size = 28
+ assert_exception(
+ mx.nd.Correlation,
+ MXNetError,
+ mx.nd.array(np.random.rand(1, 1, 28, 28)),
+ mx.nd.array(np.random.rand(1, 1, 28, 28)),
+ kernel_size=invalid_kernel_size)
+
+@with_seed()
+def test_valid_kernel_size():
+ valid_kernel_size = 9
+ mx.nd.Correlation(
+ mx.nd.array(np.random.rand(1, 1, 28, 28)),
+ mx.nd.array(np.random.rand(1, 1, 28, 28)),
+ kernel_size=valid_kernel_size)
+
if __name__ == '__main__':
import nose
nose.runmodule()