This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f9a97ca  Add more gpu tests and fix nested try catch in Imperative Invoke  (#7676)
f9a97ca is described below

commit f9a97ca8a523ec3ff074f8113f7fed11f3f0769d
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Wed Aug 30 23:12:06 2017 -0700

    Add more gpu tests and fix nested try catch in Imperative Invoke  (#7676)
    
    * add more tests in test_gpu
    
    * remove a density value for cast_storage since it's slow
    
    * fix nested try catch in imperative invoke (a sketch of the issue follows the diffstat below)
    
    * import * from test_ndarray
    
    * add inline
    
    * fix wrong args in pick
---
 src/c_api/c_api_ndarray.cc                    |  34 ++-
 tests/python/gpu/test_operator_gpu.py         |   4 +-
 tests/python/unittest/test_ndarray.py         |   4 +-
 tests/python/unittest/test_sparse_operator.py | 337 +++++++++++++-------------
 4 files changed, 198 insertions(+), 181 deletions(-)
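
The key C++ change below moves the body of MXImperativeInvoke into an inline MXImperativeInvokeImpl, so that MXImperativeInvokeEx no longer calls an exported entry point that opens its own API_BEGIN()/API_END() block. Those macros roughly wrap a C API body in try { ... } catch and convert an exception into a nonzero return code, so calling one entry point from inside another nests the try/catch: the inner handler absorbs the failure and the outer function keeps running as if the call had succeeded. The minimal sketch below illustrates the same refactoring pattern; TRY_BEGIN/TRY_END, Invoke, InvokeEx, and InvokeImpl are stand-ins invented for this example, not MXNet's actual definitions.

    // Stand-in macros: roughly the shape of API_BEGIN()/API_END(), which turn
    // an exception thrown inside a C API body into a nonzero return code.
    #include <iostream>
    #include <stdexcept>

    #define TRY_BEGIN() try {
    #define TRY_END()                                  \
      } catch (const std::exception &e) {              \
        std::cerr << "error: " << e.what() << '\n';    \
        return -1;                                     \
      }                                                \
      return 0;

    // Shared implementation: plain C++ that may throw; no error macros here.
    inline void InvokeImpl(int x) {
      if (x < 0) throw std::runtime_error("negative input");
      std::cout << "invoked with " << x << '\n';
    }

    // Plain entry point: exactly one try/catch around the shared impl.
    int Invoke(int x) {
      TRY_BEGIN();
      InvokeImpl(x);
      TRY_END();
    }

    // Extended entry point: before the fix it would have called Invoke(),
    // nesting two try/catch blocks so the inner one swallowed the error and
    // the extra work below still ran after a failure; now it calls the
    // shared impl directly inside a single handler.
    int InvokeEx(int x) {
      TRY_BEGIN();
      InvokeImpl(x);
      // ... extra work, e.g. reporting output storage types ...
      TRY_END();
    }

    int main() {
      std::cout << Invoke(1) << '\n';     // prints "invoked with 1", then 0
      std::cout << InvokeEx(-1) << '\n';  // prints the error message, then -1
      return 0;
    }

With a shared impl, each exported function keeps exactly one error-handling block at its outermost level, which is the shape the diff gives MXImperativeInvoke and MXImperativeInvokeEx below.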

diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 64fa74d..5d3fd70 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -550,19 +550,18 @@ void ImperativeInvokeImpl(const Context& default_ctx,
   }
 }
 
-int MXImperativeInvoke(AtomicSymbolCreator creator,
-                       int num_inputs,
-                       NDArrayHandle *inputs,
-                       int *num_outputs,
-                       NDArrayHandle **outputs,
-                       int num_params,
-                       const char **param_keys,
-                       const char **param_vals) {
+inline void MXImperativeInvokeImpl(AtomicSymbolCreator creator,
+                                   int num_inputs,
+                                   NDArrayHandle *inputs,
+                                   int *num_outputs,
+                                   NDArrayHandle **outputs,
+                                   int num_params,
+                                   const char **param_keys,
+                                   const char **param_vals) {
   const nnvm::Op* op = static_cast<nnvm::Op*>(creator);
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   NDArray** outarray = *reinterpret_cast<NDArray***>(outputs);
 
-  API_BEGIN();
   nnvm::NodeAttrs attrs;
   SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals);
 
@@ -588,6 +587,19 @@ int MXImperativeInvoke(AtomicSymbolCreator creator,
       *outarray[i] = std::move(ndoutputs[i]);
     }
   }
+}
+
+int MXImperativeInvoke(AtomicSymbolCreator creator,
+                       int num_inputs,
+                       NDArrayHandle *inputs,
+                       int *num_outputs,
+                       NDArrayHandle **outputs,
+                       int num_params,
+                       const char **param_keys,
+                       const char **param_vals) {
+  API_BEGIN();
+  MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs,
+                         outputs, num_params, param_keys, param_vals);
   API_END();
 }
 
@@ -601,8 +613,8 @@ int MXImperativeInvokeEx(AtomicSymbolCreator creator,
                          const char **param_vals,
                          const int **out_stypes) {  // outputs storage types
   API_BEGIN();
-  MXImperativeInvoke(creator, num_inputs, inputs, num_outputs, outputs,
-                     num_params, param_keys, param_vals);
+  MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs, outputs,
+                         num_params, param_keys, param_vals);
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
   NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs);
   ret->out_types.resize(*num_outputs);
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f9845f9..73a1f95 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -33,9 +33,9 @@ from test_gluon import *
 from test_loss import *
 #from test_rnn import *
 from test_gluon_rnn import *
-from test_sparse_operator import test_cast_storage_ex, test_sparse_dot
-from test_sparse_operator import test_sparse_nd_zeros, test_sparse_retain
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse
+from test_sparse_operator import *
+from test_ndarray import *
 
 set_default_context(mx.gpu(0))
 del test_support_vector_machine_l1_svm
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 7d11dbe..7cb6891 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -705,7 +705,7 @@ def test_ndarray_fluent():
                     'clip', 'abs' 'sign'])
     def check_fluent_regular(func, kwargs, shape=(5, 17, 1)):
         with mx.name.NameManager():
-            data = mx.nd.random_uniform(shape=shape)
+            data = mx.nd.random_uniform(shape=shape, ctx=default_context())
             regular = getattr(mx.ndarray, func)(data, **kwargs)
             fluent = getattr(data, func)(**kwargs)
             if isinstance(regular, list):
@@ -729,7 +729,7 @@ def test_ndarray_fluent():
     check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, shape=(5, 17, 6))
     check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
     check_fluent_regular('take', {'indices': mx.nd.array([2, 3])})
-    check_fluent_regular('pick', {'axis': 1, 'begin': 5, 'end': 7})
+    check_fluent_regular('pick', {'axis': 1, 'index': mx.nd.array([[2], [3], [5], [6], [11]])})
     check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
     check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
     check_fluent_regular('pad', {'mode': 'constant', 'pad_width': (0,0,0,0,3,0,0,4)}, shape=(5, 17, 2, 3))
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 2875d7b..a4c8342 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -41,48 +41,50 @@ def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_
 
 
 def test_elemwise_add_ex():
-    shapes = [rand_shape_2d(), rand_shape_3d()]
-    for shape in shapes:
-        check_elemwise_add_ex('default', 'default', shape)
-        check_elemwise_add_ex('default', 'row_sparse', shape)
-        check_elemwise_add_ex('row_sparse', 'default', shape)
-        check_elemwise_add_ex('row_sparse', 'row_sparse', shape,
-                              lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse')
+    if default_context().device_type == 'cpu':
+        shapes = [rand_shape_2d(), rand_shape_3d()]
+        for shape in shapes:
+            check_elemwise_add_ex('default', 'default', shape)
+            check_elemwise_add_ex('default', 'row_sparse', shape)
+            check_elemwise_add_ex('row_sparse', 'default', shape)
+            check_elemwise_add_ex('row_sparse', 'row_sparse', shape,
+                                  lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse')
 
 
 # TODO(haibin) randomize this test
 def test_elemwise_add_ex_multiple_stages():
-    # prep data
-    shape = (4, 2)
-    ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
-    sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]])
-    sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]])
-
-    val1 = mx.nd.array([[5, 10]]);
-    val2 = mx.nd.array([[5, 10]]);
-    idx1 = mx.nd.array([0], dtype=np.int64);
-    idx2 = mx.nd.array([1], dtype=np.int64);
-    sp_nd1 = mx.nd.sparse.row_sparse_array(val1, idx1, shape)
-    sp_nd2 = mx.nd.sparse.row_sparse_array(val2, idx2, shape)
-    ds_nd = mx.nd.array(ds_np)
-
-    # sparse + sparse = sparse
-    sp_data1 = mx.symbol.Variable('sp_data1', stype='row_sparse')
-    sp_data2 = mx.symbol.Variable('sp_data2', stype='row_sparse')
-    ds_data = mx.symbol.Variable('ds_data')
-    plus = mx.symbol.sparse.elemwise_add(sp_data1, sp_data2, name='plus')
-    # sparse + dense = dense
-    test = mx.symbol.sparse.elemwise_add(plus, ds_data)
-    check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
-                                  'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np])
-
-    arr_grads = [mx.nd.zeros(shape) for i in range(3)]
-    exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
-                                                   'ds_data': ds_nd}, args_grad=arr_grads)
-    exec_test.forward(is_train=True)
-    assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np)
-    exec_test.backward(out_grads=exec_test.outputs)
-    assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
+    if default_context().device_type == 'cpu':
+        # prep data
+        shape = (4, 2)
+        ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
+        sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]])
+        sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]])
+
+        val1 = mx.nd.array([[5, 10]]);
+        val2 = mx.nd.array([[5, 10]]);
+        idx1 = mx.nd.array([0], dtype=np.int64);
+        idx2 = mx.nd.array([1], dtype=np.int64);
+        sp_nd1 = mx.nd.sparse.row_sparse_array(val1, idx1, shape)
+        sp_nd2 = mx.nd.sparse.row_sparse_array(val2, idx2, shape)
+        ds_nd = mx.nd.array(ds_np)
+
+        # sparse + sparse = sparse
+        sp_data1 = mx.symbol.Variable('sp_data1', stype='row_sparse')
+        sp_data2 = mx.symbol.Variable('sp_data2', stype='row_sparse')
+        ds_data = mx.symbol.Variable('ds_data')
+        plus = mx.symbol.sparse.elemwise_add(sp_data1, sp_data2, name='plus')
+        # sparse + dense = dense
+        test = mx.symbol.sparse.elemwise_add(plus, ds_data)
+        check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
+                                      'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np])
+
+        arr_grads = [mx.nd.zeros(shape) for i in range(3)]
+        exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2,
+                                                       'ds_data': ds_nd}, args_grad=arr_grads)
+        exec_test.forward(is_train=True)
+        assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np)
+        exec_test.backward(out_grads=exec_test.outputs)
+        assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
 
 def test_cast_storage_ex():
     def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=True):
@@ -99,7 +101,7 @@ def test_cast_storage_ex():
         grad_stypes = {'x': to_stype}
         check_symbolic_backward(test, location, [out_np], [out_np], grad_stypes=grad_stypes)
 
-    density = [1.00, 0.50, 0.10, 0.05, 0.01]
+    density = [1.00, 0.50, 0.05, 0.01]
     for d in density:
         shape_2d = rand_shape_2d()
         shape_3d = rand_shape_3d()
@@ -231,141 +233,144 @@ def test_sparse_nd_zeros():
 
 
 def test_sparse_square_sum():
-    dim0 = 30
-    dim1 = 30
-    axes = [0, 1]
-    keepdims = [False, True]
-    densities = [0, 0.01, 0.1, 0.2, 0.5]
-    for density in densities:
-        shape = rand_shape_2d(dim0, dim1)
-        rsp = rand_ndarray(shape, 'row_sparse', density)
-        dns = rsp.tostype('default')
-        for axis in axes:
-            for keepdim in keepdims:
-                ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
-                if axis == 1 and keepdim:
-                    assert ret.stype == 'row_sparse'
-                else:
-                    assert ret.stype == 'default'
-                ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
-                # check forward result
-                assert same(ret.asnumpy(), ret_expected.asnumpy())
-
-                rsp_data = mx.sym.Variable('data', stype='row_sparse')
-                test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
-
-                # check symbolic backward since ograd can be a rsp
-                # and cannot be checked through check_numeric_gradient
-                # because it will add a loss layer as the output layer
-                # which makes ograd of the square_sum dense
-                if axis == 1 and keepdims:
-                    dns_data = mx.sym.Variable('data')
-                    baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
-                    igrad_expected = mx.nd.empty(dns.shape)
-                    baseline_exec = baseline.bind(default_context(), args=[dns],
-                                                  args_grad=[igrad_expected])
-                    baseline_exec.forward(is_train=True)
-                    baseline_exec.backward([ret_expected])
-                    check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
-                                            grad_stypes={'data': 'row_sparse'})
-
-                # check numeric gradient
-                check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
-                                       atol=1e-2, rtol=0.1)
+    if default_context().device_type == 'cpu':
+        dim0 = 30
+        dim1 = 30
+        axes = [0, 1]
+        keepdims = [False, True]
+        densities = [0, 0.01, 0.1, 0.2, 0.5]
+        for density in densities:
+            shape = rand_shape_2d(dim0, dim1)
+            rsp = rand_ndarray(shape, 'row_sparse', density)
+            dns = rsp.tostype('default')
+            for axis in axes:
+                for keepdim in keepdims:
+                    ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
+                    if axis == 1 and keepdim:
+                        assert ret.stype == 'row_sparse'
+                    else:
+                        assert ret.stype == 'default'
+                    ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
+                    # check forward result
+                    assert same(ret.asnumpy(), ret_expected.asnumpy())
+
+                    rsp_data = mx.sym.Variable('data', stype='row_sparse')
+                    test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
+
+                    # check symbolic backward since ograd can be a rsp
+                    # and cannot be checked through check_numeric_gradient
+                    # because it will add a loss layer as the output layer
+                    # which makes ograd of the square_sum dense
+                    if axis == 1 and keepdims:
+                        dns_data = mx.sym.Variable('data')
+                        baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
+                        igrad_expected = mx.nd.empty(dns.shape)
+                        baseline_exec = baseline.bind(default_context(), args=[dns],
+                                                      args_grad=[igrad_expected])
+                        baseline_exec.forward(is_train=True)
+                        baseline_exec.backward([ret_expected])
+                        check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
+                                                grad_stypes={'data': 'row_sparse'})
+
+                    # check numeric gradient
+                    check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
+                                           atol=1e-2, rtol=0.1)
 
 def test_sparse_storage_fallback():
     """ test operators which don't implement FComputeEx or FStatefulComputeEx 
"""
-    def check_broadcast_add(shape, lhs_stype, rhs_stype):
-        lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
-        rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
-        lhs_nd = rand_ndarray(shape, lhs_stype)
-        rhs_nd = rand_ndarray(shape, rhs_stype)
-        lhs_dns = mx.nd.cast_storage(lhs_nd, stype='default')
-        rhs_dns = mx.nd.cast_storage(rhs_nd, stype='default')
+    if default_context().device_type == 'cpu':
+        def check_broadcast_add(shape, lhs_stype, rhs_stype):
+            lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
+            rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
+            lhs_nd = rand_ndarray(shape, lhs_stype)
+            rhs_nd = rand_ndarray(shape, rhs_stype)
+            lhs_dns = mx.nd.cast_storage(lhs_nd, stype='default')
+            rhs_dns = mx.nd.cast_storage(rhs_nd, stype='default')
+
+            out_dns = (lhs_dns + rhs_dns).asnumpy()
+            test = mx.symbol.broadcast_add(lhs, rhs)
+            location = {'lhs': lhs_nd, 'rhs': rhs_nd}
+            check_symbolic_forward(test, location, [out_dns])
+            check_numeric_gradient(test, location)
+            check_symbolic_backward(test, location, [out_dns], [out_dns, out_dns])
+
+        def np_softmax(x, axis=-1):
+            # fix for old numpy on Travis not supporting keepdims
+            # x = x - np.max(x, axis=-1, keepdims=True)
+            x = x - np.max(x, axis=axis, keepdims=True)
+            x = np.exp(x)
+            # x /= np.sum(x, axis=-1, keepdims=True)
+            x /= np.sum(x, axis=axis, keepdims=True)
+            return x
+
+        def check_softmax_with_shape(lhs_stype, rhs_stype, shape, preserve_shape=False):
+            # bind with label
+            ctx = default_context()
+            X = mx.symbol.Variable('X', stype=lhs_stype)
+            L = mx.symbol.Variable('L', stype=rhs_stype)
+            Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
+            x = rand_ndarray(shape, lhs_stype)
+            l = rand_ndarray(shape, rhs_stype)
+            l[:] = np_softmax(l.asnumpy())
+            grad = mx.nd.empty(shape, ctx=ctx)
+            exec1 = Y.bind(ctx, args = [x, l], args_grad = {'X': grad})
+            exec1.forward(is_train=True)
+            out = exec1.outputs[0].asnumpy()
+            assert_almost_equal(out, np_softmax(x.asnumpy()), rtol=1e-4)
+            exec1.backward()
+            assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(),
+                                rtol=1e-3, atol=1e-4)
 
-        out_dns = (lhs_dns + rhs_dns).asnumpy()
-        test = mx.symbol.broadcast_add(lhs, rhs)
-        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
-        check_symbolic_forward(test, location, [out_dns])
-        check_numeric_gradient(test, location)
-        check_symbolic_backward(test, location, [out_dns], [out_dns, out_dns])
-
-    def np_softmax(x, axis=-1):
-        # fix for old numpy on Travis not supporting keepdims
-        # x = x - np.max(x, axis=-1, keepdims=True)
-        x = x - np.max(x, axis=axis, keepdims=True)
-        x = np.exp(x)
-        # x /= np.sum(x, axis=-1, keepdims=True)
-        x /= np.sum(x, axis=axis, keepdims=True)
-        return x
-
-    def check_softmax_with_shape(lhs_stype, rhs_stype, shape, preserve_shape=False):
-        # bind with label
-        ctx = default_context()
-        X = mx.symbol.Variable('X', stype=lhs_stype)
-        L = mx.symbol.Variable('L', stype=rhs_stype)
-        Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
-        x = rand_ndarray(shape, lhs_stype)
-        l = rand_ndarray(shape, rhs_stype)
-        l[:] = np_softmax(l.asnumpy())
-        grad = mx.nd.empty(shape, ctx=ctx)
-        exec1 = Y.bind(ctx, args = [x, l], args_grad = {'X': grad})
-        exec1.forward(is_train=True)
-        out = exec1.outputs[0].asnumpy()
-        assert_almost_equal(out, np_softmax(x.asnumpy()), rtol=1e-4)
-        exec1.backward()
-        assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(),
-                            rtol=1e-3, atol=1e-4)
-
-    def check_concat(shape, lhs_stype, rhs_stype):
-        x = mx.symbol.Variable('x', stype=lhs_stype)
-        w = mx.symbol.Variable('w', stype=rhs_stype)
-        test = mx.sym.Concat(x, w)
-        x_nd = rand_ndarray(shape, lhs_stype)
-        w_nd = rand_ndarray(shape, rhs_stype)
-        location = {'x': x_nd, 'w': w_nd}
-        check_numeric_gradient(test, location)
+        def check_concat(shape, lhs_stype, rhs_stype):
+            x = mx.symbol.Variable('x', stype=lhs_stype)
+            w = mx.symbol.Variable('w', stype=rhs_stype)
+            test = mx.sym.Concat(x, w)
+            x_nd = rand_ndarray(shape, lhs_stype)
+            w_nd = rand_ndarray(shape, rhs_stype)
+            location = {'x': x_nd, 'w': w_nd}
+            check_numeric_gradient(test, location)
 
-    shape = rand_shape_2d()
-    stypes = ['default', 'csr', 'row_sparse']
-    for lhs in stypes:
-        for rhs in stypes:
-            check_broadcast_add(shape, lhs, rhs)
-            check_concat(shape, lhs, rhs)
-            check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
-            check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
+        shape = rand_shape_2d()
+        stypes = ['default', 'csr', 'row_sparse']
+        for lhs in stypes:
+            for rhs in stypes:
+                check_broadcast_add(shape, lhs, rhs)
+                check_concat(shape, lhs, rhs)
+                check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
+                check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
 
 
 def test_sparse_elementwise_sum():
-    def check_sparse_elementwise_sum_with_shape(stype, shape, n):
-        # forward
-        inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
-        out = mx.symbol.sparse.add_n(*inputs, name='esum')
-        arr = []
-        arr_grad = [mx.nd.empty(shape) for _ in range(n)]
-        densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5]
-        for i in range(n):
-            arr.append(rand_ndarray(shape, stype, np.random.randint(0, len(densities))))
-
-        exec1 = out.bind(default_context(),
-                         args=arr,
-                         args_grad=arr_grad)
-        exec1.forward(is_train=True)
-        out1 = exec1.outputs[0].asnumpy()
-        out = sum(a.asnumpy() for a in arr)
-        assert_almost_equal(out, out1)
-
-        out_grad = mx.nd.empty(shape)
-        out_grad[:] = np.random.uniform(-10, 10, shape)
-        # backward
-        exec1.backward([out_grad])
-        for a in arr_grad:
-            assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
-
-    maxdim = 5
-    for dim in range(2, maxdim):
-        shape = tuple(np.random.randint(5, 10, size=dim))
-        check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
+    if default_context().device_type == 'cpu':
+        def check_sparse_elementwise_sum_with_shape(stype, shape, n):
+            # forward
+            inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
+            out = mx.symbol.sparse.add_n(*inputs, name='esum')
+            arr = []
+            arr_grad = [mx.nd.empty(shape) for _ in range(n)]
+            densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5]
+            for i in range(n):
+                arr.append(rand_ndarray(shape, stype, np.random.randint(0, len(densities))))
+
+            exec1 = out.bind(default_context(),
+                             args=arr,
+                             args_grad=arr_grad)
+            exec1.forward(is_train=True)
+            out1 = exec1.outputs[0].asnumpy()
+            out = sum(a.asnumpy() for a in arr)
+            assert_almost_equal(out, out1)
+
+            out_grad = mx.nd.empty(shape)
+            out_grad[:] = np.random.uniform(-10, 10, shape)
+            # backward
+            exec1.backward([out_grad])
+            for a in arr_grad:
+                assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
+
+        maxdim = 5
+        for dim in range(2, maxdim):
+            shape = tuple(np.random.randint(5, 10, size=dim))
+            check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
 
 
 if __name__ == '__main__':

-- 
To stop receiving notification emails like this one, please contact comm...@mxnet.apache.org.
