This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a7e7cdc  add ctx for rand_ndarray and rand_sparse_ndarray (#14966)
a7e7cdc is described below

commit a7e7cdc0c4ffedcea2cb2ad6982b341f35412cb1
Author: Hao Jin <[email protected]>
AuthorDate: Sat May 18 11:13:28 2019 -0700

    add ctx for rand_ndarray and rand_sparse_ndarray (#14966)
---
 python/mxnet/test_utils.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 7b46be4..fb40474 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -260,7 +260,7 @@ def assign_each2(input1, input2, function):
 
 def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
                         data_init=None, rsp_indices=None, modifier_func=None,
-                        shuffle_csr_indices=False):
+                        shuffle_csr_indices=False, ctx=None):
     """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)
 
     Parameters
@@ -301,6 +301,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
     >>> assert(row4nnz == 2*row3nnz)
 
     """
+    ctx = ctx if ctx else default_context()
     density = rnd.rand() if density is None else density
     dtype = default_dtype() if dtype is None else dtype
     distribution = "uniform" if distribution is None else distribution
@@ -315,7 +316,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
             idx_sample = rnd.rand(shape[0])
             indices = np.argwhere(idx_sample < density).flatten()
         if indices.shape[0] == 0:
-            result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype)
+            result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx)
             return result, (np.array([], dtype=dtype), np.array([]))
         # generate random values
         val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
@@ -326,17 +327,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
         if modifier_func is not None:
             val = assign_each(val, modifier_func)
 
-        arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype)
+        arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx)
         return arr, (val, indices)
     elif stype == 'csr':
         assert len(shape) == 2
         if distribution == "uniform":
             csr = _get_uniform_dataset_csr(shape[0], shape[1], density,
                                            data_init=data_init,
-                                           shuffle_csr_indices=shuffle_csr_indices, dtype=dtype)
+                                           shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx)
             return csr, (csr.indptr, csr.indices, csr.data)
         elif distribution == "powerlaw":
-            csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype)
+            csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx)
             return csr, (csr.indptr, csr.indices, csr.data)
         else:
             assert(False), "Distribution not supported: %s" % (distribution)
@@ -345,15 +346,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
         assert(False), "unknown storage type"
         return False
 
-def rand_ndarray(shape, stype='default', density=None, dtype=None,
-                 modifier_func=None, shuffle_csr_indices=False, distribution=None):
+def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None,
+                 shuffle_csr_indices=False, distribution=None, ctx=None):
+    """Generate a random sparse ndarray. Returns the generated ndarray."""
+    ctx = ctx if ctx else default_context()
     if stype == 'default':
-        arr = mx.nd.array(random_arrays(shape), dtype=dtype)
+        arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx)
     else:
         arr, _ = rand_sparse_ndarray(shape, stype, density=density,
                                      modifier_func=modifier_func, dtype=dtype,
                                      shuffle_csr_indices=shuffle_csr_indices,
-                                     distribution=distribution)
+                                     distribution=distribution, ctx=ctx)
     return arr
 
 

Reply via email to