This is an automated email from the ASF dual-hosted git repository. anirudh2290 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 4e19a32 remove flaky test and add consistency test for stable testing (#12427) 4e19a32 is described below commit 4e19a328ae94c893ed11591b798aaebf33f39052 Author: Hao Jin <haoj...@users.noreply.github.com> AuthorDate: Wed Sep 5 11:34:54 2018 -0700 remove flaky test and add consistency test for stable testing (#12427) --- src/operator/bilinear_sampler-inl.h | 3 + src/operator/bilinear_sampler.cu | 6 +- tests/python/gpu/test_operator_gpu.py | 59 ++++++++++++ tests/python/unittest/test_operator.py | 160 --------------------------------- 4 files changed, 67 insertions(+), 161 deletions(-) diff --git a/src/operator/bilinear_sampler-inl.h b/src/operator/bilinear_sampler-inl.h index 657aeba..e0b4db7 100644 --- a/src/operator/bilinear_sampler-inl.h +++ b/src/operator/bilinear_sampler-inl.h @@ -44,7 +44,10 @@ enum BilinearSamplerOpOutputs {kOut, kTmp}; } struct BilinearSamplerParam : public dmlc::Parameter<BilinearSamplerParam> { + dmlc::optional<bool> cudnn_off; DMLC_DECLARE_PARAMETER(BilinearSamplerParam) { + DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>()) + .describe("whether to turn cudnn off"); } }; diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu index 0ab628d..e1f2052 100644 --- a/src/operator/bilinear_sampler.cu +++ b/src/operator/bilinear_sampler.cu @@ -212,7 +212,11 @@ Operator* CreateOp<gpu>(BilinearSamplerParam param, int dtype) { Operator *op = NULL; #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new CuDNNBilinearSamplerOp<DType>(param); + if (param.cudnn_off.has_value() && param.cudnn_off.value()) { + op = new BilinearSamplerOp<gpu, DType>(param); + } else { + op = new CuDNNBilinearSamplerOp<DType>(param); + } }) #else MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 0ff33e1..7b75275 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1918,6 +1918,65 @@ def test_softmax_activation(): assert_almost_equal(cpu_a.grad.asnumpy(), gpu_a.grad.asnumpy(), atol = 1e-3, rtol = 1e-3) + +@with_seed() +def test_bilinear_sampler_versions(): + data = mx.sym.Variable('data') + grid = mx.sym.Variable('grid') + sym1 = mx.sym.BilinearSampler(data=data, grid=grid) + sym2 = mx.sym.BilinearSampler(data=data, grid=grid, cudnn_off=True) + sym3 = mx.sym.BilinearSampler(data=data, grid=grid) + + test_cases = [[(1,3,15,16),(1,2,10,10)], + [(1,6,7,16),(1,2,10,4)], + [(1,7,3,16),(1,2,8,11)], + [(1,9,50,50),(1,2,50,50)]] + + for item in test_cases: + data_shape, grid_shape = item + # kWriteTo + exe_cpu = sym1.simple_bind(data=data_shape, grid=grid_shape, ctx=mx.cpu(), grad_req='write') + exe_gpu = sym2.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req='write') + exe_cudnn = sym3.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req='write') + exe_list = [exe_cpu, exe_gpu, exe_cudnn] + ref_idx = 0 + test_data = np.random.uniform(low=-0.1, high=0.1,size=data_shape).astype(np.float32) + test_grid = np.random.uniform(low=-2, high=2, size=grid_shape).astype(np.float32) + for exe in exe_list: + exe.arg_dict['data'][:] = test_data + exe.arg_dict['grid'][:] = test_grid + exe.forward(is_train=True) + assert_almost_equal(exe_list[0].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5) + + out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32) + for exe in exe_list: + exe.backward(mx.nd.array(out_grad)) + assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) + + data_grad = exe_list[ref_idx].grad_dict['data'].asnumpy() + grid_grad = exe_list[ref_idx].grad_dict['grid'].asnumpy() + + # kAddTo + exe_cpu_addto = sym1.simple_bind(data=data_shape, grid=grid_shape, ctx=mx.cpu(), grad_req='add') + exe_gpu_addto = sym2.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req='add') + exe_cudnn_addto = sym3.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req='add') + exe_list = [exe_cpu_addto, exe_gpu_addto, exe_cudnn_addto] + data_initial_grad = np.random.normal(size=exe_list[ref_idx].grad_dict['data'].shape).astype(np.float32) + grid_initial_grad = np.random.normal(size=exe_list[ref_idx].grad_dict['grid'].shape).astype(np.float32) + for exe in exe_list: + exe.arg_dict['data'][:] = test_data + exe.arg_dict['grid'][:] = test_grid + exe.grad_dict['data'][:] = data_initial_grad + exe.grad_dict['grid'][:] = grid_initial_grad + exe.forward(is_train=True) + exe.backward(mx.nd.array(out_grad)) + assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe_list[ref_idx].grad_dict['data'].asnumpy(), data_grad + data_initial_grad, rtol=1e-3, atol=1e-5) + assert_almost_equal(exe_list[ref_idx].grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5) + + def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ca358ef..9842a69 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3945,166 +3945,6 @@ def test_grid_generator(): assert_almost_equal(exe_add.grad_dict['flow'].asnumpy(), grad_est + flow_grad_npy, rtol=1e-3, atol=1e-5) -@unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/12248") -def test_bilinear_sampler(): - from math import floor - - def between(x, lowerbound, upperbound): - return x>=lowerbound and x<=upperbound - - def bilinear_forward_numpy(data, grid): - - batchsize = data.shape[0] - input_height = data.shape[2] - input_width = data.shape[3] - num_channel = data.shape[1] - - output_height = grid.shape[2] - output_width = grid.shape[3] - out = np.zeros(data.shape[:2] + grid.shape[2:], dtype=np.float32) - - for i in range(batchsize): - for yout in range(output_height): - for xout in range(output_width): - - xcoord = np.float32((grid[i, 0, yout, xout] + 1) * (input_width-1) / 2.0) - ycoord = np.float32((grid[i, 1, yout, xout] + 1) * (input_height-1) / 2.0) - - xInTopLeft = int(floor(xcoord)) - xWeightTopLeft = np.float32(1-(xcoord - xInTopLeft)) - - yInTopLeft = int(floor(ycoord)) - yWeightTopLeft = np.float32(1-(ycoord - yInTopLeft)) - - # interpolation - for channel in range(num_channel): - - inTopLeft = data[i,channel,yInTopLeft, xInTopLeft] \ - if between(xInTopLeft,0,input_width-1) and between(yInTopLeft,0,input_height-1) else 0.0 - inTopRight = data[i,channel,yInTopLeft, xInTopLeft+1] \ - if between(xInTopLeft+1,0,input_width-1) and between(yInTopLeft,0,input_height-1) else 0.0 - inBottomLeft = data[i,channel,yInTopLeft+1, xInTopLeft] \ - if between(xInTopLeft,0,input_width-1) and between(yInTopLeft+1,0,input_height-1) else 0.0 - inBottomRight = data[i,channel,yInTopLeft+1, xInTopLeft+1] \ - if between(xInTopLeft+1,0,input_width-1) and between(yInTopLeft+1,0,input_height-1) else 0.0 - - out[i,channel,yout,xout] = xWeightTopLeft * yWeightTopLeft * inTopLeft\ - + (1-xWeightTopLeft)*yWeightTopLeft * inTopRight\ - + xWeightTopLeft * (1-yWeightTopLeft) * inBottomLeft\ - +(1-xWeightTopLeft) * (1-yWeightTopLeft) * inBottomRight - return out - - def bilinear_backward_numpy(out_grad, data, grid): - - data_grad = np.zeros(data.shape, dtype=np.float32) - grid_grad = np.zeros(grid.shape, dtype=np.float32) - - batchsize = data.shape[0] - input_height = data.shape[2] - input_width = data.shape[3] - num_channel = data.shape[1] - output_height = grid.shape[2] - output_width = grid.shape[3] - - for i in range(batchsize): - for yout in range(output_height): - for xout in range(output_width): - - top_left_y_gw = np.float32(0.0); - top_left_x_gw = np.float32(0.0); - - xcoord = np.float32((grid[i, 0, yout, xout] + 1) * (input_width-1) / 2.0) - ycoord = np.float32((grid[i, 1, yout, xout] + 1) * (input_height-1) / 2.0) - - xInTopLeft = int(floor(xcoord)) - xWeightTopLeft = np.float32(1-(xcoord - xInTopLeft)) - - yInTopLeft = int(floor(ycoord)) - yWeightTopLeft = np.float32(1-(ycoord - yInTopLeft)) - - topLeftDotProduct = np.float32(0) - topRightDotProduct = np.float32(0) - bottomLeftDotProduct = np.float32(0) - bottomRightDotProduct = np.float32(0) - - for channel in range(num_channel): - # left top - if between(xInTopLeft,0,input_width-1) and between(yInTopLeft,0,input_height-1): - topLeftDotProduct += data[i,channel,yInTopLeft, xInTopLeft] * \ - out_grad[i,channel,yout,xout] - data_grad[i, channel, yInTopLeft, xInTopLeft] += xWeightTopLeft * \ - yWeightTopLeft * out_grad[i,channel,yout,xout] - # right top - if between(xInTopLeft+1,0,input_width-1) and between(yInTopLeft,0,input_height-1): - topRightDotProduct += data[i, channel, yInTopLeft,xInTopLeft+1] * \ - out_grad[i, channel, yout,xout] - data_grad[i, channel,yInTopLeft, xInTopLeft+1] += (1-xWeightTopLeft) * \ - yWeightTopLeft * out_grad[i,channel,yout,xout] - # left bottom - if between(xInTopLeft,0,input_width-1) and between(yInTopLeft+1,0,input_height-1): - bottomLeftDotProduct += data[i, channel,yInTopLeft+1, xInTopLeft] * \ - out_grad[i,channel,yout,xout] - data_grad[i,channel,yInTopLeft+1,xInTopLeft]+=xWeightTopLeft * \ - (1-yWeightTopLeft)* out_grad[i,channel,yout,xout] - # right bottom - if between(xInTopLeft+1,0,input_width-1) and between(yInTopLeft+1,0,input_height-1): - bottomRightDotProduct += data[i,channel,yInTopLeft+1, xInTopLeft+1] * \ - out_grad[i,channel,yout,xout] - data_grad[i,channel,yInTopLeft+1,xInTopLeft+1]+= (1-xWeightTopLeft) * \ - (1-yWeightTopLeft)*out_grad[i,channel,yout,xout] - - yf = np.float32(-xWeightTopLeft * topLeftDotProduct + xWeightTopLeft*bottomLeftDotProduct - \ - (1-xWeightTopLeft)* topRightDotProduct + (1-xWeightTopLeft)*bottomRightDotProduct) - xf = np.float32(-yWeightTopLeft * topLeftDotProduct + yWeightTopLeft*topRightDotProduct - \ - (1-yWeightTopLeft)*bottomLeftDotProduct + (1-yWeightTopLeft)*bottomRightDotProduct) - - grid_grad[i,0,yout,xout] = xf * (input_width-1) / 2.0 - grid_grad[i,1,yout,xout] = yf * (input_height-1) / 2.0 - - return data_grad, grid_grad - - data = mx.sym.Variable('data') - grid = mx.sym.Variable('grid') - net = mx.sym.BilinearSampler(data=data,grid=grid) - - test_case = [[(1,3,15,16),(1,2,10,10)], - [(1,6,7,16),(1,2,10,4)], - [(1,7,3,16),(1,2,8,11)], - [(1,9,50,50),(1,2,50,50)]] - - for ctx in [default_context()]: - for item in test_case: - data_shape, grid_shape = item - exe = net.simple_bind(data=data_shape,grid=grid_shape,ctx=ctx,grad_req='write') - # check forward - exe.arg_dict['data'][:] = np.random.uniform(low=-0.1, high=0.1,size=data_shape).astype(np.float32) - exe.arg_dict['grid'][:] = np.random.uniform(low=-2, high=2, size=grid_shape).astype(np.float32) - exe.forward(is_train=True) - out = bilinear_forward_numpy(exe.arg_dict['data'].asnumpy(), exe.arg_dict['grid'].asnumpy()) - assert_almost_equal(exe.outputs[0].asnumpy(), out, rtol=1e-3,atol=1e-5) - - # check backward - out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32) - exe.backward(mx.nd.array(out_grad)) - data_grad, grid_grad = bilinear_backward_numpy(out_grad,exe.arg_dict['data'].asnumpy(), - exe.arg_dict['grid'].asnumpy()) - assert_almost_equal(exe.grad_dict['data'].asnumpy(), data_grad, rtol=1e-3, atol=1e-5) - assert_almost_equal(exe.grad_dict['grid'].asnumpy(), grid_grad, rtol=1e-3, atol=1e-5) - - # check kAddTo - exe_addto = net.simple_bind(data=data_shape, grid=grid_shape, ctx=ctx, grad_req='add') - data_initial_grid = np.random.normal(size=exe_addto.grad_dict['data'].shape).astype(np.float32) - grid_initial_grid = np.random.normal(size=exe_addto.grad_dict['grid'].shape).astype(np.float32) - exe_addto.arg_dict['data'][:] = exe.arg_dict['data'][:] - exe_addto.arg_dict['grid'][:] = exe.arg_dict['grid'][:] - exe_addto.grad_dict['data'][:] = data_initial_grid - exe_addto.grad_dict['grid'][:] = grid_initial_grid - exe_addto.forward(is_train=True) - exe_addto.backward(mx.nd.array(out_grad)) - assert_almost_equal(exe_addto.grad_dict['data'].asnumpy(), data_grad + data_initial_grid, rtol=1e-3,atol=1e-5) - assert_almost_equal(exe_addto.grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grid, rtol=1e-3,atol=1e-5) - - @with_seed() def test_index2d(): for _ in range(30):