This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a7e7cdc add ctx for rand_ndarray and rand_sparse_ndarray (#14966)
a7e7cdc is described below
commit a7e7cdc0c4ffedcea2cb2ad6982b341f35412cb1
Author: Hao Jin <[email protected]>
AuthorDate: Sat May 18 11:13:28 2019 -0700
add ctx for rand_ndarray and rand_sparse_ndarray (#14966)
---
python/mxnet/test_utils.py | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 7b46be4..fb40474 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -260,7 +260,7 @@ def assign_each2(input1, input2, function):
def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
- shuffle_csr_indices=False):
+ shuffle_csr_indices=False, ctx=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)
Parameters
@@ -301,6 +301,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
>>> assert(row4nnz == 2*row3nnz)
"""
+ ctx = ctx if ctx else default_context()
density = rnd.rand() if density is None else density
dtype = default_dtype() if dtype is None else dtype
distribution = "uniform" if distribution is None else distribution
@@ -315,7 +316,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
idx_sample = rnd.rand(shape[0])
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
- result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype)
+ result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx)
return result, (np.array([], dtype=dtype), np.array([]))
# generate random values
val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
@@ -326,17 +327,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
if modifier_func is not None:
val = assign_each(val, modifier_func)
- arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype)
+ arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx)
return arr, (val, indices)
elif stype == 'csr':
assert len(shape) == 2
if distribution == "uniform":
csr = _get_uniform_dataset_csr(shape[0], shape[1], density,
data_init=data_init,
- shuffle_csr_indices=shuffle_csr_indices, dtype=dtype)
+ shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
elif distribution == "powerlaw":
- csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype)
+ csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
else:
assert(False), "Distribution not supported: %s" % (distribution)
@@ -345,15 +346,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
assert(False), "unknown storage type"
return False
-def rand_ndarray(shape, stype='default', density=None, dtype=None,
- modifier_func=None, shuffle_csr_indices=False, distribution=None):
+def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None,
+ shuffle_csr_indices=False, distribution=None, ctx=None):
+ """Generate a random sparse ndarray. Returns the generated ndarray."""
+ ctx = ctx if ctx else default_context()
if stype == 'default':
- arr = mx.nd.array(random_arrays(shape), dtype=dtype)
+ arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx)
else:
arr, _ = rand_sparse_ndarray(shape, stype, density=density,
modifier_func=modifier_func, dtype=dtype,
shuffle_csr_indices=shuffle_csr_indices,
- distribution=distribution)
+ distribution=distribution, ctx=ctx)
return arr