Author: Michael Cheng <coolbutusel...@gmail.com> Branch: Changeset: r54360:886d352cf776 Date: 2012-04-14 20:04 +1000 http://bitbucket.org/pypy/pypy/changeset/886d352cf776/
Log: swapaxes for numpypy diff --git a/lib_pypy/numpypy/core/fromnumeric.py b/lib_pypy/numpypy/core/fromnumeric.py --- a/lib_pypy/numpypy/core/fromnumeric.py +++ b/lib_pypy/numpypy/core/fromnumeric.py @@ -411,7 +411,8 @@ [3, 7]]]) """ - raise NotImplementedError('Waiting on interp level method') + swapaxes = a.swapaxes + return swapaxes(axis1, axis2) def transpose(a, axes=None): diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -513,7 +513,30 @@ arr = concrete.copy(space) arr.setshape(space, new_shape) return arr - + + @unwrap_spec(axis1=int, axis2=int) + def descr_swapaxes(self, space, axis1, axis2): + """a.swapaxes(axis1, axis2) + + Return a view of the array with `axis1` and `axis2` interchanged. + + Refer to `numpy.swapaxes` for full documentation. + + See Also + -------- + numpy.swapaxes : equivalent function + """ + concrete = self.get_concrete() + shape = concrete.shape[:] + strides = concrete.strides[:] + backstrides = concrete.backstrides[:] + shape[axis1], shape[axis2] = shape[axis2], shape[axis1] + strides[axis1], strides[axis2] = strides[axis2], strides[axis1] + backstrides[axis1], backstrides[axis2] = backstrides[axis2], backstrides[axis1] + arr = W_NDimSlice(concrete.start, strides, + backstrides, shape, concrete) + return space.wrap(arr) + def descr_tolist(self, space): if len(self.shape) == 0: assert isinstance(self, Scalar) @@ -1412,6 +1435,7 @@ copy = interp2app(BaseArray.descr_copy), flatten = interp2app(BaseArray.descr_flatten), reshape = interp2app(BaseArray.descr_reshape), + swapaxes = interp2app(BaseArray.descr_swapaxes), tolist = interp2app(BaseArray.descr_tolist), take = interp2app(BaseArray.descr_take), compress = interp2app(BaseArray.descr_compress), diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1410,6 +1410,35 @@ assert (array([1, 2]).repeat(2) == array([1, 1, 2, 2])).all() + def test_swapaxes(self): + from _numpypy import array + # testcases from numpy docstring + x = array([[1, 2, 3]]) + assert (x.swapaxes(0, 1) == array([[1], [2], [3]])).all() + x = array([[[0,1],[2,3]],[[4,5],[6,7]]]) # shape = (2, 2, 2) + assert (x.swapaxes(0, 2) == array([[[0, 4], [2, 6]], + [[1, 5], [3, 7]]])).all() + assert (x.swapaxes(0, 1) == array([[[0, 1], [4, 5]], + [[2, 3], [6, 7]]])).all() + assert (x.swapaxes(1, 2) == array([[[0, 2], [1, 3]], + [[4, 6],[5, 7]]])).all() + + # more complex shape i.e. (2, 2, 3) + x = array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + assert (x.swapaxes(0, 1) == array([[[1, 2, 3], [7, 8, 9]], + [[4, 5, 6], [10, 11, 12]]])).all() + assert (x.swapaxes(0, 2) == array([[[1, 7], [4, 10]], [[2, 8], [5, 11]], + [[3, 9], [6, 12]]])).all() + assert (x.swapaxes(1, 2) == array([[[1, 4], [2, 5], [3, 6]], + [[7, 10], [8, 11],[9, 12]]])).all() + + # test slice + assert (x[0:1,0:2].swapaxes(0,2) == array([[[1], [4]], [[2], [5]], + [[3], [6]]])).all() + # test virtual + assert ((x + x).swapaxes(0,1) == array([[[ 2, 4, 6], [14, 16, 18]], + [[ 8, 10, 12], [20, 22, 24]]])).all() + class AppTestMultiDim(BaseNumpyAppTest): def test_init(self): import _numpypy diff --git a/pypy/module/test_lib_pypy/numpypy/core/test_fromnumeric.py b/pypy/module/test_lib_pypy/numpypy/core/test_fromnumeric.py --- a/pypy/module/test_lib_pypy/numpypy/core/test_fromnumeric.py +++ b/pypy/module/test_lib_pypy/numpypy/core/test_fromnumeric.py @@ -136,4 +136,11 @@ raises(NotImplementedError, "transpose(x, axes=(1, 0, 2))") # x = ones((1, 2, 3)) # assert transpose(x, (1, 0, 2)).shape == (2, 1, 3) - + + def test_fromnumeric(self): + from numpypy import array, swapaxes + x = array([[1,2,3]]) + assert (swapaxes(x,0,1) == array([[1], [2], [3]])).all() + x = array([[[0,1],[2,3]],[[4,5],[6,7]]]) + assert (swapaxes(x,0,2) == array([[[0, 4], [2, 6]], + [[1, 5], [3, 7]]])).all() _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit