This is an automated email from the ASF dual-hosted git repository. anirudh2290 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 74d6692 Fix spatial transformer op (#11806) 74d6692 is described below commit 74d669293f68cf7e9325dcedc569164f3d7cfa90 Author: Anirudh Subramanian <anirudh2...@apache.org> AuthorDate: Mon Jul 23 18:43:32 2018 -0700 Fix spatial transformer op (#11806) --- src/operator/spatial_transformer.cu | 15 +++++++++------ tests/python/unittest/common.py | 7 ++----- tests/python/unittest/test_operator.py | 3 --- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu index 4a39733..f1d69f7 100644 --- a/src/operator/spatial_transformer.cu +++ b/src/operator/spatial_transformer.cu @@ -110,20 +110,21 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h, DType bottom_right_v = 0; // calc input grad if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; + atomicAdd((g_input + data_index), *(grad + grad_index) * top_left_y_w * top_left_x_w); top_left_v = *(data + data_index); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w); + atomicAdd((g_input + data_index + 1), + *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w)); top_right_v = *(data + data_index + 1); } if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w; - bottom_left_v = *(data + data_index + i_w); + atomicAdd((g_input + data_index + i_w), + *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) * - (1.0 - top_left_x_w); + atomicAdd((g_input + data_index + i_w + 1), + *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w)); bottom_right_v = *(data + data_index + i_w + 1); } // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src @@ -157,6 +158,7 @@ inline void BilinearSamplingForward(const Tensor<gpu, 4, DType> &output, cudaStream_t stream = Stream<gpu>::GetStream(output.stream_); BilinearSamplingForwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >( i_c, i_h, i_w, data, grid, o_n, o_c, o_h, o_w, out); + MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingForwardKernel); } template<typename DType> @@ -180,6 +182,7 @@ inline void BilinearSamplingBackward(const Tensor<gpu, 4, DType> &input_grad, cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_); BilinearSamplingBackwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >( i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src); + MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingBackwardKernel); } } // namespace mshadow diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 65c1886..f98bb79 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -95,7 +95,7 @@ def random_seed(seed=None): random.seed(next_seed) -def assert_raises_cudnn_disabled(assertion_error=False): +def assert_raises_cudnn_disabled(): def test_helper(orig_test): @make_decorator(orig_test) def test_new(*args, **kwargs): @@ -103,10 +103,7 @@ def assert_raises_cudnn_disabled(assertion_error=False): if not cudnn_disabled or mx.context.current_context().device_type == 'cpu': orig_test(*args, **kwargs) else: - if assertion_error: - errors = (MXNetError, RuntimeError, AssertionError) - else: - errors = (MXNetError, RuntimeError) + errors = (MXNetError, RuntimeError) assert_raises(errors, orig_test, *args, **kwargs) return test_new return test_helper diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 54fa0a7..11180eb 100644
--- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2416,9 +2416,6 @@ def test_flip(): @with_seed() -# The test is disabled with USE_CUDA=ON and USE_CUDNN=OFF because of failures with the SpatialTransformer op. -# Tracked at https://github.com/apache/incubator-mxnet/issues/11568 -@assert_raises_cudnn_disabled(assertion_error=True) def test_stn(): np.set_printoptions(threshold=np.nan) num_filter = 2 # conv of loc net