This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 9348a3a  make array.reshape compatible with numpy (#9790)
9348a3a is described below

commit 9348a3ab9491a2835ab3b9652e2a1f2750ac1823
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Mon Feb 19 11:50:36 2018 -0800

    make array.reshape compatible with numpy (#9790)

    * make array.reshape compatible with numpy

    * update

    * add exception when both *args and **kwargs are specified

    * update
---
 python/mxnet/ndarray/ndarray.py       | 20 ++++++++++++++++++--
 python/mxnet/ndarray/sparse.py        |  2 +-
 tests/python/unittest/test_ndarray.py |  5 +++++
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index b089cc5..4c5273f 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -926,12 +926,12 @@ fixed-size items.
             self.handle, mx_uint(idx), ctypes.byref(handle)))
         return NDArray(handle=handle, writable=self.writable)
 
-    def reshape(self, shape):
+    def reshape(self, *shape, **kwargs):
         """Returns a **view** of this array with a new shape without altering any data.
 
         Parameters
         ----------
-        shape : tuple of int
+        shape : tuple of int, or n ints
             The new shape should not change the array size, namely
             ``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``.
 
@@ -963,11 +963,27 @@ fixed-size items.
         array([[ 0., 1.],
                [ 2., 3.],
                [ 4., 5.]], dtype=float32)
+        >>> y = x.reshape(3,2)
+        >>> y.asnumpy()
+        array([[ 0., 1.],
+               [ 2., 3.],
+               [ 4., 5.]], dtype=float32)
         >>> y[:] = -1
         >>> x.asnumpy()
         array([[-1., -1., -1.],
                [-1., -1., -1.]], dtype=float32)
         """
+        if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
+            shape = shape[0]
+        elif not shape:
+            shape = kwargs.get('shape')
+            assert shape, "Shape must be provided."
+            if len(kwargs) != 1:
+                raise TypeError("Only 'shape' is supported as keyword argument. Got: {}."
+                                .format(', '.join(kwargs.keys())))
+        else:
+            assert not kwargs,\
+                "Specifying both positional and keyword arguments is not allowed in reshape."
         handle = NDArrayHandle()
 
         # Actual reshape
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index 62ee32a..c65d7ce 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -138,7 +138,7 @@ class BaseSparseNDArray(NDArray):
     def _slice(self, start, stop):
         raise NotSupportedForSparseNDArray(self._slice, None, start, stop)
 
-    def reshape(self, shape):
+    def reshape(self, *shape, **kwargs):
         raise NotSupportedForSparseNDArray(self.reshape, None, shape)
 
     @property
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 9fae7ab..6c10487 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -158,6 +158,11 @@ def test_ndarray_reshape():
                             [5, 6],
                             [7, 8]])
     assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
+    true_res = mx.nd.arange(8) + 1
+    assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())
 
 
 @with_seed()

-- 
To stop receiving notification emails like this one, please contact j...@apache.org.