This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new f9a97ca  Add more gpu tests and fix nested try catch in Imperative Invoke (#7676)
f9a97ca is described below

commit f9a97ca8a523ec3ff074f8113f7fed11f3f0769d
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Wed Aug 30 23:12:06 2017 -0700

    Add more gpu tests and fix nested try catch in Imperative Invoke (#7676)

    * add more test in test_gpu

    * remove a density for cast_storage since its slow

    * fix nested try catch in imperative invoke

    * import * from test_ndarary

    * add inline

    * fix wrong args in pick
---
 src/c_api/c_api_ndarray.cc                    |  34 ++-
 tests/python/gpu/test_operator_gpu.py         |   4 +-
 tests/python/unittest/test_ndarray.py         |   4 +-
 tests/python/unittest/test_sparse_operator.py | 337 +++++++++++++-------------
 4 files changed, 198 insertions(+), 181 deletions(-)
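A note on the nested try/catch fix, before the patch itself: MXNet's C API
entry points are wrapped in API_BEGIN()/API_END() macros that expand to a
try/catch pair. Before this commit, MXImperativeInvokeEx called
MXImperativeInvoke directly, so the callee opened a second handler inside the
caller's, and an exception in the body was converted to an error return code
by the inner handler that the outer call never checked (visible in the
c_api_ndarray.cc hunks below). The commit factors the shared body into an
inline MXImperativeInvokeImpl() that each entry point wraps in exactly one
handler. The following is a minimal sketch of that pattern; the macros are
simplified stand-ins and the names (Invoke, InvokeEx, InvokeImpl) are
hypothetical, not the actual MXNet definitions:

    #include <cstdio>
    #include <stdexcept>

    // Simplified stand-ins for MXNet's API_BEGIN/API_END error boundary.
    #define API_BEGIN() try {
    #define API_END() } catch (const std::exception &e) { \
        std::fprintf(stderr, "%s\n", e.what());           \
        return -1; } return 0;

    // Shared body: does the work but owns no try/catch of its own.
    inline void InvokeImpl(int x) {
      if (x < 0) throw std::runtime_error("negative input");
    }

    // Each C API entry point wraps the body in exactly one handler.
    int Invoke(int x) {
      API_BEGIN();
      InvokeImpl(x);
      API_END();
    }

    int InvokeEx(int x) {
      API_BEGIN();
      InvokeImpl(x);  // before the fix, this called Invoke(), nesting handlers
      API_END();
    }

With this structure an exception raised anywhere in the body propagates to
the single handler of whichever entry point was called, instead of being
absorbed by an inner handler whose error code the outer function ignored.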
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 64fa74d..5d3fd70 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -550,19 +550,18 @@ void ImperativeInvokeImpl(const Context& default_ctx,
   }
 }
 
-int MXImperativeInvoke(AtomicSymbolCreator creator,
-                       int num_inputs,
-                       NDArrayHandle *inputs,
-                       int *num_outputs,
-                       NDArrayHandle **outputs,
-                       int num_params,
-                       const char **param_keys,
-                       const char **param_vals) {
+inline void MXImperativeInvokeImpl(AtomicSymbolCreator creator,
+                                   int num_inputs,
+                                   NDArrayHandle *inputs,
+                                   int *num_outputs,
+                                   NDArrayHandle **outputs,
+                                   int num_params,
+                                   const char **param_keys,
+                                   const char **param_vals) {
   const nnvm::Op* op = static_cast<nnvm::Op*>(creator);
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   NDArray** outarray = *reinterpret_cast<NDArray***>(outputs);
 
-  API_BEGIN();
   nnvm::NodeAttrs attrs;
   SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals);
 
@@ -588,6 +587,19 @@ int MXImperativeInvoke(AtomicSymbolCreator creator,
       *outarray[i] = std::move(ndoutputs[i]);
     }
   }
+}
+
+int MXImperativeInvoke(AtomicSymbolCreator creator,
+                       int num_inputs,
+                       NDArrayHandle *inputs,
+                       int *num_outputs,
+                       NDArrayHandle **outputs,
+                       int num_params,
+                       const char **param_keys,
+                       const char **param_vals) {
+  API_BEGIN();
+  MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs,
+                         outputs, num_params, param_keys, param_vals);
   API_END();
 }
 
@@ -601,8 +613,8 @@ int MXImperativeInvokeEx(AtomicSymbolCreator creator,
                          const char **param_vals,
                          const int **out_stypes) {  // outputs storage types
   API_BEGIN();
-  MXImperativeInvoke(creator, num_inputs, inputs, num_outputs, outputs,
-                     num_params, param_keys, param_vals);
+  MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs, outputs,
+                         num_params, param_keys, param_vals);
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs);
   ret->out_types.resize(*num_outputs);
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f9845f9..73a1f95 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -33,9 +33,9 @@ from test_gluon import *
 from test_loss import *
 #from test_rnn import *
 from test_gluon_rnn import *
-from test_sparse_operator import test_cast_storage_ex, test_sparse_dot
-from test_sparse_operator import test_sparse_nd_zeros, test_sparse_retain
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse
+from test_sparse_operator import *
+from test_ndarray import *
 
 set_default_context(mx.gpu(0))
 del test_support_vector_machine_l1_svm
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 7d11dbe..7cb6891 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -705,7 +705,7 @@ def test_ndarray_fluent():
                                       'clip', 'abs' 'sign'])
     def check_fluent_regular(func, kwargs, shape=(5, 17, 1)):
         with mx.name.NameManager():
-            data = mx.nd.random_uniform(shape=shape)
+            data = mx.nd.random_uniform(shape=shape, ctx=default_context())
             regular = getattr(mx.ndarray, func)(data, **kwargs)
             fluent = getattr(data, func)(**kwargs)
             if isinstance(regular, list):
@@ -729,7 +729,7 @@ def test_ndarray_fluent():
     check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, shape=(5, 17, 6))
     check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
     check_fluent_regular('take', {'indices': mx.nd.array([2, 3])})
-    check_fluent_regular('pick', {'axis': 1, 'begin': 5, 'end': 7})
+    check_fluent_regular('pick', {'axis': 1, 'index': mx.nd.array([[2], [3], [5], [6], [11]])})
     check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
     check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
     check_fluent_regular('pad', {'mode': 'constant', 'pad_width': (0,0,0,0,3,0,0,4)}, shape=(5, 17, 2, 3))
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 2875d7b..a4c8342 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -41,48 +41,50 @@ def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_
 
 
 def test_elemwise_add_ex():
-    shapes = [rand_shape_2d(), rand_shape_3d()]
-    for shape in shapes:
-        check_elemwise_add_ex('default', 'default', shape)
-        check_elemwise_add_ex('default', 'row_sparse', shape)
-        check_elemwise_add_ex('row_sparse', 'default', shape)
-        check_elemwise_add_ex('row_sparse', 'row_sparse', shape,
-                              lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse')
+    if default_context().device_type == 'cpu':
+        shapes = [rand_shape_2d(), rand_shape_3d()]
+        for shape in shapes:
+            check_elemwise_add_ex('default', 'default', shape)
+            check_elemwise_add_ex('default', 'row_sparse', shape)
+            check_elemwise_add_ex('row_sparse', 'default', shape)
+            check_elemwise_add_ex('row_sparse', 'row_sparse', shape,
+                                  lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse')
 
 
 # TODO(haibin) randomize this test
 def test_elemwise_add_ex_multiple_stages():
-    # prep data
-    shape = (4, 2)
-    ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
-    sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]])
-    sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]])
-
-    val1 = mx.nd.array([[5, 10]]);
-    val2 = mx.nd.array([[5, 10]]);
-    idx1 = mx.nd.array([0], dtype=np.int64);
-    idx2 = mx.nd.array([1], dtype=np.int64);
-    sp_nd1 = mx.nd.sparse.row_sparse_array(val1, idx1, shape)
-    sp_nd2 = mx.nd.sparse.row_sparse_array(val2, idx2, shape)
-    ds_nd = mx.nd.array(ds_np)
-
-    # sparse + sparse = sparse
-    sp_data1 = mx.symbol.Variable('sp_data1', stype='row_sparse')
-    sp_data2 = mx.symbol.Variable('sp_data2', stype='row_sparse')
-    ds_data = mx.symbol.Variable('ds_data')
-    plus = mx.symbol.sparse.elemwise_add(sp_data1, sp_data2, name='plus')
-    # sparse + dense = dense
-    test = mx.symbol.sparse.elemwise_add(plus, ds_data)
-    check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
-                                  'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np])
-
-    arr_grads = [mx.nd.zeros(shape) for i in range(3)]
-    exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
-                                                   'ds_data': ds_nd}, args_grad=arr_grads)
-    exec_test.forward(is_train=True)
-    assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np)
-    exec_test.backward(out_grads=exec_test.outputs)
-    assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
+    if default_context().device_type == 'cpu':
+        # prep data
+        shape = (4, 2)
+        ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
+        sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]])
+        sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]])
+
+        val1 = mx.nd.array([[5, 10]]);
+        val2 = mx.nd.array([[5, 10]]);
+        idx1 = mx.nd.array([0], dtype=np.int64);
+        idx2 = mx.nd.array([1], dtype=np.int64);
+        sp_nd1 = mx.nd.sparse.row_sparse_array(val1, idx1, shape)
+        sp_nd2 = mx.nd.sparse.row_sparse_array(val2, idx2, shape)
+        ds_nd = mx.nd.array(ds_np)
+
+        # sparse + sparse = sparse
+        sp_data1 = mx.symbol.Variable('sp_data1', stype='row_sparse')
+        sp_data2 = mx.symbol.Variable('sp_data2', stype='row_sparse')
+        ds_data = mx.symbol.Variable('ds_data')
+        plus = mx.symbol.sparse.elemwise_add(sp_data1, sp_data2, name='plus')
+        # sparse + dense = dense
+        test = mx.symbol.sparse.elemwise_add(plus, ds_data)
+        check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
+                                      'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np])
+
+        arr_grads = [mx.nd.zeros(shape) for i in range(3)]
+        exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
+                                                       'ds_data': ds_nd}, args_grad=arr_grads)
+        exec_test.forward(is_train=True)
+        assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np)
+        exec_test.backward(out_grads=exec_test.outputs)
+        assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
 
 def test_cast_storage_ex():
     def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=True):
@@ -99,7 +101,7 @@ def test_cast_storage_ex():
             grad_stypes = {'x': to_stype}
             check_symbolic_backward(test, location, [out_np], [out_np], grad_stypes=grad_stypes)
 
-    density = [1.00, 0.50, 0.10, 0.05, 0.01]
+    density = [1.00, 0.50, 0.05, 0.01]
     for d in density:
         shape_2d = rand_shape_2d()
         shape_3d = rand_shape_3d()
@@ -231,141 +233,144 @@ def test_sparse_nd_zeros():
 
 
 def test_sparse_square_sum():
-    dim0 = 30
-    dim1 = 30
-    axes = [0, 1]
-    keepdims = [False, True]
-    densities = [0, 0.01, 0.1, 0.2, 0.5]
-    for density in densities:
-        shape = rand_shape_2d(dim0, dim1)
-        rsp = rand_ndarray(shape, 'row_sparse', density)
-        dns = rsp.tostype('default')
-        for axis in axes:
-            for keepdim in keepdims:
-                ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
-                if axis == 1 and keepdim:
-                    assert ret.stype == 'row_sparse'
-                else:
-                    assert ret.stype == 'default'
-                ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
-                # check forward result
-                assert same(ret.asnumpy(), ret_expected.asnumpy())
-
-                rsp_data = mx.sym.Variable('data', stype='row_sparse')
-                test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
-
-                # check symbolic backward since ograd can be a rsp
-                # and cannot be checked through check_numeric_gradient
-                # because it will add a loss layer as the output layer
-                # which makes ograd of the square_sum dense
-                if axis == 1 and keepdims:
-                    dns_data = mx.sym.Variable('data')
-                    baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
-                    igrad_expected = mx.nd.empty(dns.shape)
-                    baseline_exec = baseline.bind(default_context(), args=[dns],
-                                                  args_grad=[igrad_expected])
-                    baseline_exec.forward(is_train=True)
-                    baseline_exec.backward([ret_expected])
-                    check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
-                                            grad_stypes={'data': 'row_sparse'})
-
-                # check numeric gradient
-                check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
-                                       atol=1e-2, rtol=0.1)
+    if default_context().device_type == 'cpu':
+        dim0 = 30
+        dim1 = 30
+        axes = [0, 1]
+        keepdims = [False, True]
+        densities = [0, 0.01, 0.1, 0.2, 0.5]
+        for density in densities:
+            shape = rand_shape_2d(dim0, dim1)
+            rsp = rand_ndarray(shape, 'row_sparse', density)
+            dns = rsp.tostype('default')
+            for axis in axes:
+                for keepdim in keepdims:
+                    ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
+                    if axis == 1 and keepdim:
+                        assert ret.stype == 'row_sparse'
+                    else:
+                        assert ret.stype == 'default'
+                    ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
+                    # check forward result
+                    assert same(ret.asnumpy(), ret_expected.asnumpy())
+
+                    rsp_data = mx.sym.Variable('data', stype='row_sparse')
+                    test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
+
+                    # check symbolic backward since ograd can be a rsp
+                    # and cannot be checked through check_numeric_gradient
+                    # because it will add a loss layer as the output layer
+                    # which makes ograd of the square_sum dense
+                    if axis == 1 and keepdims:
+                        dns_data = mx.sym.Variable('data')
+                        baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
+                        igrad_expected = mx.nd.empty(dns.shape)
+                        baseline_exec = baseline.bind(default_context(), args=[dns],
+                                                      args_grad=[igrad_expected])
+                        baseline_exec.forward(is_train=True)
+                        baseline_exec.backward([ret_expected])
+                        check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
+                                                grad_stypes={'data': 'row_sparse'})
+
+                    # check numeric gradient
+                    check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
+                                           atol=1e-2, rtol=0.1)
 
 
 def test_sparse_storage_fallback():
     """ test operators which don't implement FComputeEx or FStatefulComputeEx """
-    def check_broadcast_add(shape, lhs_stype, rhs_stype):
-        lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
-        rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
-        lhs_nd = rand_ndarray(shape, lhs_stype)
-        rhs_nd = rand_ndarray(shape, rhs_stype)
-        lhs_dns = mx.nd.cast_storage(lhs_nd, stype='default')
-        rhs_dns = mx.nd.cast_storage(rhs_nd, stype='default')
+    if default_context().device_type == 'cpu':
+        def check_broadcast_add(shape, lhs_stype, rhs_stype):
+            lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
+            rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
+            lhs_nd = rand_ndarray(shape, lhs_stype)
+            rhs_nd = rand_ndarray(shape, rhs_stype)
+            lhs_dns = mx.nd.cast_storage(lhs_nd, stype='default')
+            rhs_dns = mx.nd.cast_storage(rhs_nd, stype='default')
+
+            out_dns = (lhs_dns + rhs_dns).asnumpy()
+            test = mx.symbol.broadcast_add(lhs, rhs)
+            location = {'lhs': lhs_nd, 'rhs': rhs_nd}
+            check_symbolic_forward(test, location, [out_dns])
+            check_numeric_gradient(test, location)
+            check_symbolic_backward(test, location, [out_dns], [out_dns, out_dns])
+
+        def np_softmax(x, axis=-1):
+            # fix for old numpy on Travis not supporting keepdims
+            # x = x - np.max(x, axis=-1, keepdims=True)
+            x = x - np.max(x, axis=axis, keepdims=True)
+            x = np.exp(x)
+            # x /= np.sum(x, axis=-1, keepdims=True)
+            x /= np.sum(x, axis=axis, keepdims=True)
+            return x
+
+        def check_softmax_with_shape(lhs_stype, rhs_stype, shape, preserve_shape=False):
+            # bind with label
+            ctx = default_context()
+            X = mx.symbol.Variable('X', stype=lhs_stype)
+            L = mx.symbol.Variable('L', stype=rhs_stype)
+            Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
+            x = rand_ndarray(shape, lhs_stype)
+            l = rand_ndarray(shape, rhs_stype)
+            l[:] = np_softmax(l.asnumpy())
+            grad = mx.nd.empty(shape, ctx=ctx)
+            exec1 = Y.bind(ctx, args = [x, l], args_grad = {'X': grad})
+            exec1.forward(is_train=True)
+            out = exec1.outputs[0].asnumpy()
+            assert_almost_equal(out, np_softmax(x.asnumpy()), rtol=1e-4)
+            exec1.backward()
+            assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(),
+                                rtol=1e-3, atol=1e-4)
-        out_dns = (lhs_dns + rhs_dns).asnumpy()
-        test = mx.symbol.broadcast_add(lhs, rhs)
-        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
-        check_symbolic_forward(test, location, [out_dns])
-        check_numeric_gradient(test, location)
-        check_symbolic_backward(test, location, [out_dns], [out_dns, out_dns])
-
-    def np_softmax(x, axis=-1):
-        # fix for old numpy on Travis not supporting keepdims
-        # x = x - np.max(x, axis=-1, keepdims=True)
-        x = x - np.max(x, axis=axis, keepdims=True)
-        x = np.exp(x)
-        # x /= np.sum(x, axis=-1, keepdims=True)
-        x /= np.sum(x, axis=axis, keepdims=True)
-        return x
-
-    def check_softmax_with_shape(lhs_stype, rhs_stype, shape, preserve_shape=False):
-        # bind with label
-        ctx = default_context()
-        X = mx.symbol.Variable('X', stype=lhs_stype)
-        L = mx.symbol.Variable('L', stype=rhs_stype)
-        Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
-        x = rand_ndarray(shape, lhs_stype)
-        l = rand_ndarray(shape, rhs_stype)
-        l[:] = np_softmax(l.asnumpy())
-        grad = mx.nd.empty(shape, ctx=ctx)
-        exec1 = Y.bind(ctx, args = [x, l], args_grad = {'X': grad})
-        exec1.forward(is_train=True)
-        out = exec1.outputs[0].asnumpy()
-        assert_almost_equal(out, np_softmax(x.asnumpy()), rtol=1e-4)
-        exec1.backward()
-        assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(),
-                            rtol=1e-3, atol=1e-4)
-
-    def check_concat(shape, lhs_stype, rhs_stype):
-        x = mx.symbol.Variable('x', stype=lhs_stype)
-        w = mx.symbol.Variable('w', stype=rhs_stype)
-        test = mx.sym.Concat(x, w)
-        x_nd = rand_ndarray(shape, lhs_stype)
-        w_nd = rand_ndarray(shape, rhs_stype)
-        location = {'x': x_nd, 'w': w_nd}
-        check_numeric_gradient(test, location)
+        def check_concat(shape, lhs_stype, rhs_stype):
+            x = mx.symbol.Variable('x', stype=lhs_stype)
+            w = mx.symbol.Variable('w', stype=rhs_stype)
+            test = mx.sym.Concat(x, w)
+            x_nd = rand_ndarray(shape, lhs_stype)
+            w_nd = rand_ndarray(shape, rhs_stype)
+            location = {'x': x_nd, 'w': w_nd}
+            check_numeric_gradient(test, location)
-    shape = rand_shape_2d()
-    stypes = ['default', 'csr', 'row_sparse']
-    for lhs in stypes:
-        for rhs in stypes:
-            check_broadcast_add(shape, lhs, rhs)
-            check_concat(shape, lhs, rhs)
-            check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
-            check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
+        shape = rand_shape_2d()
+        stypes = ['default', 'csr', 'row_sparse']
+        for lhs in stypes:
+            for rhs in stypes:
+                check_broadcast_add(shape, lhs, rhs)
+                check_concat(shape, lhs, rhs)
+                check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
+                check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
 
 
 def test_sparse_elementwise_sum():
-    def check_sparse_elementwise_sum_with_shape(stype, shape, n):
-        # forward
-        inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
-        out = mx.symbol.sparse.add_n(*inputs, name='esum')
-        arr = []
-        arr_grad = [mx.nd.empty(shape) for _ in range(n)]
-        densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5]
-        for i in range(n):
-            arr.append(rand_ndarray(shape, stype, np.random.randint(0, len(densities))))
-
-        exec1 = out.bind(default_context(),
-                         args=arr,
-                         args_grad=arr_grad)
-        exec1.forward(is_train=True)
-        out1 = exec1.outputs[0].asnumpy()
-        out = sum(a.asnumpy() for a in arr)
-        assert_almost_equal(out, out1)
-
-        out_grad = mx.nd.empty(shape)
-        out_grad[:] = np.random.uniform(-10, 10, shape)
-        # backward
-        exec1.backward([out_grad])
-        for a in arr_grad:
-            assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
-
-    maxdim = 5
-    for dim in range(2, maxdim):
-        shape = tuple(np.random.randint(5, 10, size=dim))
-        check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
+    if default_context().device_type == 'cpu':
+        def check_sparse_elementwise_sum_with_shape(stype, shape, n):
+            # forward
+            inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
+            out = mx.symbol.sparse.add_n(*inputs, name='esum')
+            arr = []
+            arr_grad = [mx.nd.empty(shape) for _ in range(n)]
+            densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5]
+            for i in range(n):
+                arr.append(rand_ndarray(shape, stype, np.random.randint(0, len(densities))))
+
+            exec1 = out.bind(default_context(),
+                             args=arr,
+                             args_grad=arr_grad)
+            exec1.forward(is_train=True)
+            out1 = exec1.outputs[0].asnumpy()
+            out = sum(a.asnumpy() for a in arr)
+            assert_almost_equal(out, out1)
+
+            out_grad = mx.nd.empty(shape)
+            out_grad[:] = np.random.uniform(-10, 10, shape)
+            # backward
+            exec1.backward([out_grad])
+            for a in arr_grad:
+                assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
+
+        maxdim = 5
+        for dim in range(2, maxdim):
+            shape = tuple(np.random.randint(5, 10, size=dim))
+            check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
 
 
 if __name__ == '__main__':
-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].