Author: mattip <matti.pi...@gmail.com>
Branch:
Changeset: r83278:e7bacd0b61e2
Date: 2016-03-22 21:55 +0200
http://bitbucket.org/pypy/pypy/changeset/e7bacd0b61e2/

Log: Merged in sergem/pypy/fix_transpose_for_list_v3 (pull request #420)

Fixed ndarray.transpose when argument is a list or an array

diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -502,29 +502,34 @@
         return W_NDimArray(self.implementation.transpose(self, axes))
 
     def descr_transpose(self, space, args_w):
-        if len(args_w) == 1 and space.isinstance_w(args_w[0], space.w_tuple):
-            args_w = space.fixedview(args_w[0])
-        if (len(args_w) == 0 or
-                len(args_w) == 1 and space.is_none(args_w[0])):
+        if len(args_w) == 0 or len(args_w) == 1 and space.is_none(args_w[0]):
             return self.descr_get_transpose(space)
         else:
-            if len(args_w) != self.ndims():
-                raise oefmt(space.w_ValueError, "axes don't match array")
-            axes = []
-            axes_seen = [False] * self.ndims()
-            for w_arg in args_w:
-                try:
-                    axis = support.index_w(space, w_arg)
-                except OperationError:
-                    raise oefmt(space.w_TypeError, "an integer is required")
-                if axis < 0 or axis >= self.ndims():
-                    raise oefmt(space.w_ValueError, "invalid axis for this array")
-                if axes_seen[axis] is True:
-                    raise oefmt(space.w_ValueError, "repeated axis in transpose")
-                axes.append(axis)
-                axes_seen[axis] = True
-            return self.descr_get_transpose(space, axes)
+            if len(args_w) > 1:
+                axes = args_w
+            else: # Iterable in the only argument (len(arg_w) == 1 and arg_w[0] is not None)
+                axes = space.fixedview(args_w[0])
+            axes = self._checked_axes(axes, space)
+            return self.descr_get_transpose(space, axes)
+
+    def _checked_axes(self, axes_raw, space):
+        if len(axes_raw) != self.ndims():
+            raise oefmt(space.w_ValueError, "axes don't match array")
+        axes = []
+        axes_seen = [False] * self.ndims()
+        for elem in axes_raw:
+            try:
+                axis = support.index_w(space, elem)
+            except OperationError:
+                raise oefmt(space.w_TypeError, "an integer is required")
+            if axis < 0 or axis >= self.ndims():
+                raise oefmt(space.w_ValueError, "invalid axis for this array")
+            if axes_seen[axis] is True:
+                raise oefmt(space.w_ValueError, "repeated axis in transpose")
+            axes.append(axis)
+            axes_seen[axis] = True
+        return axes
 
     @unwrap_spec(axis1=int, axis2=int)
     def descr_swapaxes(self, space, axis1, axis2):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -2960,6 +2960,36 @@
         assert (a.transpose() == b).all()
         assert (a.transpose(None) == b).all()
 
+    def test_transpose_arg_tuple(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+
+        transpose_test = a.transpose((1, 2, 0))
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_list(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+
+        transpose_test = a.transpose([1, 2, 0])
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
+    def test_transpose_arg_array(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        transpose_args = a.transpose(1, 2, 0)
+
+        transpose_test = a.transpose(np.array([1, 2, 0]))
+
+        assert transpose_test.shape == (3, 4, 2)
+        assert (transpose_args == transpose_test).all()
+
     def test_transpose_error(self):
         import numpy as np
         a = np.arange(24).reshape(2, 3, 4)
@@ -2968,6 +2998,11 @@
         raises(ValueError, a.transpose, 1, 0, 1)
         raises(TypeError, a.transpose, 1, 0, '2')
 
+    def test_transpose_unexpected_argument(self):
+        import numpy as np
+        a = np.array([[1, 2], [3, 4], [5, 6]])
+        raises(TypeError, 'a.transpose(axes=(1,2,0))')
+
     def test_flatiter(self):
         from numpy import array, flatiter, arange, zeros
         a = array([[10, 30], [40, 60]])
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit