piiswrong closed pull request #10767: add reverse option to ndarray inplace reshape
URL: https://github.com/apache/incubator-mxnet/pull/10767
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not display it otherwise):
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 3f040515c2f..9ac90d68c67 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -663,6 +663,7 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
+ bool reverse,
NDArrayHandle *out);
/*!
* \brief get the shape of the array
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 361aa240daf..9fce52f36de 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -989,6 +989,19 @@ def reshape(self, *shape, **kwargs):
- input shape = (2,3,4), shape = (-4,1,2,-2), output shape
=(1,2,3,4)
- input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape =
(2,1,3,4)
+ - If the argument `reverse` is set to 1, then the special values are inferred from right
+ to left.
+
+ Example::
+
+ - without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be
+ (40,5).
+ - with reverse=1, output shape will be (50,4).
+
+ reverse : bool, default False
+ If true then the special values are inferred from right to left. Only supported as
+ keyword argument.
+
Returns
-------
@@ -1029,18 +1042,19 @@ def reshape(self, *shape, **kwargs):
elif not shape:
shape = kwargs.get('shape')
assert shape, "Shape must be provided."
- if len(kwargs) != 1:
- raise TypeError("Only 'shape' is supported as keyword
argument. Got: {}."
- .format(', '.join(kwargs.keys())))
- else:
- assert not kwargs,\
- "Specifying both positional and keyword arguments is not
allowed in reshape."
+ if not all(k in ['shape', 'reverse'] for k in kwargs):
+ raise TypeError(
+ "Got unknown keywords in reshape: {}. " \
+ "Accepted keyword arguments are 'shape' and 'reverse'.".format(
+ ', '.join([k for k in kwargs if k not in ['shape',
'reverse']])))
+ reverse = kwargs.get('reverse', False)
handle = NDArrayHandle()
# Actual reshape
check_call(_LIB.MXNDArrayReshape64(self.handle,
len(shape),
c_array(ctypes.c_int64, shape),
+ reverse,
ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 34b4fd22f85..b3dcd6a65d9 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -431,12 +431,13 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
+ bool reverse,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
- TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), false);
+ TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(),
reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 030816ecbbc..9ff2f1af312 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -154,30 +154,23 @@ def test_ndarray_negate():
@with_seed()
def test_ndarray_reshape():
- tensor = mx.nd.array([[[1, 2], [3, 4]],
- [[5, 6], [7, 8]]])
- true_res = mx.nd.arange(8) + 1
- assert same(tensor.reshape((-1, )).asnumpy(), true_res.asnumpy())
- true_res = mx.nd.array([[1, 2, 3, 4],
- [5, 6, 7, 8]])
- assert same(tensor.reshape((2, -1)).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape((0, -1)).asnumpy(), true_res.asnumpy())
- true_res = mx.nd.array([[1, 2],
- [3, 4],
- [5, 6],
- [7, 8]])
- assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
- true_res = mx.nd.arange(8) + 1
+ tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
+ true_res = mx.nd.arange(30) + 1
+ assert same(tensor.reshape((-1,)).asnumpy(), true_res.asnumpy())
+ assert same(tensor.reshape((2, -1)).asnumpy(), true_res.reshape(2,
15).asnumpy())
+ assert same(tensor.reshape((0, -1)).asnumpy(), true_res.reshape(2,
15).asnumpy())
+ assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.reshape(15,
2).asnumpy())
+ assert same(tensor.reshape(6, 5).asnumpy(), true_res.reshape(6,
5).asnumpy())
+ assert same(tensor.reshape(-1, 2).asnumpy(), true_res.reshape(15,
2).asnumpy())
assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
- assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())
-
- assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2,
4).asnumpy())
- assert same(tensor.reshape(-1, 4).asnumpy(), true_res.reshape(2,
4).asnumpy())
- assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 2,
2).asnumpy())
- assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(4,
2).asnumpy())
- assert same(tensor.reshape(-1, 4).reshape(0, -4, 2, -1).asnumpy(),
true_res.reshape(2, 2, 2).asnumpy())
+ assert same(tensor.reshape(30).asnumpy(), true_res.asnumpy())
+ assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2,
15).asnumpy())
+ assert same(tensor.reshape(-1, 6).asnumpy(), true_res.reshape(5,
6).asnumpy())
+ assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 3,
5).asnumpy())
+ assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(6,
5).asnumpy())
+ assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(),
true_res.reshape(2, 3, 5).asnumpy())
+ assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10,
3).asnumpy())
+ assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(),
true_res.reshape(6, 5).asnumpy())
@with_seed()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services