piiswrong closed pull request #9790: make array.reshape compatible with numpy
URL: https://github.com/apache/incubator-mxnet/pull/9790
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it after the merge):

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index b089cc5117..4c5273fd40 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -926,12 +926,12 @@ def _at(self, idx):
             self.handle, mx_uint(idx), ctypes.byref(handle)))
         return NDArray(handle=handle, writable=self.writable)
 
-    def reshape(self, shape):
+    def reshape(self, *shape, **kwargs):
         """Returns a **view** of this array with a new shape without altering 
any data.
 
         Parameters
         ----------
-        shape : tuple of int
+        shape : tuple of int, or n ints
             The new shape should not change the array size, namely
             ``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``.
 
@@ -960,6 +960,11 @@ def reshape(self, shape):
                [ 4.,  5.]], dtype=float32)
         >>> y = x.reshape((3,-1))
         >>> y.asnumpy()
+        array([[ 0.,  1.],
+               [ 2.,  3.],
+               [ 4.,  5.]], dtype=float32)
+        >>> y = x.reshape(3,2)
+        >>> y.asnumpy()
         array([[ 0.,  1.],
                [ 2.,  3.],
                [ 4.,  5.]], dtype=float32)
@@ -968,6 +973,17 @@ def reshape(self, shape):
         array([[-1., -1., -1.],
                [-1., -1., -1.]], dtype=float32)
         """
+        if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
+            shape = shape[0]
+        elif not shape:
+            shape = kwargs.get('shape')
+            assert shape, "Shape must be provided."
+            if len(kwargs) != 1:
+                raise TypeError("Only 'shape' is supported as keyword 
argument. Got: {}."
+                                .format(', '.join(kwargs.keys())))
+        else:
+            assert not kwargs,\
+                "Specifying both positional and keyword arguments is not 
allowed in reshape."
         handle = NDArrayHandle()
 
         # Actual reshape
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index 62ee32a9e3..c65d7ce408 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -138,7 +138,7 @@ def _at(self, idx):
     def _slice(self, start, stop):
         raise NotSupportedForSparseNDArray(self._slice, None, start, stop)
 
-    def reshape(self, shape):
+    def reshape(self, *shape, **kwargs):
         raise NotSupportedForSparseNDArray(self.reshape, None, shape)
 
     @property
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 22ff6e8cf5..78804244a2 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -154,6 +154,11 @@ def test_ndarray_reshape():
                              [5, 6],
                              [7, 8]])
     assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
+    true_res = mx.nd.arange(8) + 1
+    assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
+    assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())
 
 
 def test_ndarray_choose():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to