This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 53d786d  [MXNET-952] Check for correlation kernel size along with unittest (#12558)
53d786d is described below

commit 53d786dc6e3485714c9d1c23a4b2e3e1857922e7
Author: Chaitanya Prakash Bapat <[email protected]>
AuthorDate: Tue Sep 18 10:53:46 2018 -0700

    [MXNET-952] Check for correlation kernel size along with unittest (#12558)
    
    * added a line to check for kernel size and unittest for the same
    
    * Update correlation-inl.h
    
    * Update test_operator.py
    
    * fix space and long line issue
---
 src/operator/correlation-inl.h         |  1 +
 tests/python/unittest/test_operator.py | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/src/operator/correlation-inl.h b/src/operator/correlation-inl.h
index 9dca44e..e1cc972 100644
--- a/src/operator/correlation-inl.h
+++ b/src/operator/correlation-inl.h
@@ -78,6 +78,7 @@ class CorrelationOp : public Operator {
     using namespace mshadow;
     CHECK_EQ(in_data.size(), 2U);
     CHECK_EQ(out_data.size(), 3U);
+    CHECK_NE(param_.kernel_size % 2, 0) << "kernel size should be odd number";
     Stream<xpu> *s = ctx.get_stream<xpu>();
     Tensor<xpu, 4, DType> data1 = in_data[Correlation::kData1].get<xpu, 4, DType>(s);
     Tensor<xpu, 4, DType> data2 = in_data[Correlation::kData2].get<xpu, 4, DType>(s);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2bf7e84..5937ecd 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6913,6 +6913,24 @@ def test_spacetodepth():
     test_invalid_block_size()
     test_invalid_depth_dim()
 
+@with_seed()
+def test_invalid_kernel_size():  # an even kernel_size must be rejected
+    invalid_kernel_size = 28  # even, so kernel_size % 2 == 0 fails the check
+    assert_exception(  # NOTE(review): confirm the C++ check raises at call time
+        mx.nd.Correlation,
+        MXNetError,  # presumably the exception type assert_exception expects
+        mx.nd.array(np.random.rand(1, 1, 28, 28)),
+        mx.nd.array(np.random.rand(1, 1, 28, 28)),
+        kernel_size=invalid_kernel_size)
+
+@with_seed()
+def test_valid_kernel_size():  # an odd kernel_size must be accepted
+    valid_kernel_size = 9  # odd, so it passes the kernel_size % 2 check
+    mx.nd.Correlation(  # NOTE(review): result discarded; only a raised error fails
+        mx.nd.array(np.random.rand(1, 1, 28, 28)),
+        mx.nd.array(np.random.rand(1, 1, 28, 28)),
+        kernel_size=valid_kernel_size)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to