Author: Ronan Lamy <[email protected]>
Branch: 
Changeset: r78490:0e674dc322e0
Date: 2015-07-07 18:10 +0100
http://bitbucket.org/pypy/pypy/changeset/0e674dc322e0/

Log:    Handle errors in np.transpose()

diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -403,12 +403,21 @@
                 len(args_w) == 1 and space.is_none(args_w[0])):
             return self.descr_get_transpose(space)
         else:
+            if len(args_w) != self.ndims():
+                raise oefmt(space.w_ValueError, "axes don't match array")
             axes = []
+            axes_seen = [False] * self.ndims()
             for w_arg in args_w:
                 try:
-                    axes.append(support.index_w(space, w_arg))
+                    axis = support.index_w(space, w_arg)
                 except OperationError:
                     raise oefmt(space.w_TypeError, "an integer is required")
+                if axis < 0 or axis >= self.ndims():
+                    raise oefmt(space.w_ValueError, "invalid axis for this array")
+                if axes_seen[axis] is True:
+                    raise oefmt(space.w_ValueError, "repeated axis in transpose")
+                axes.append(axis)
+                axes_seen[axis] = True
             return self.descr_get_transpose(space, axes)
 
 
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -2781,6 +2781,14 @@
         assert (a.transpose() == b).all()
         assert (a.transpose(None) == b).all()
 
+    def test_transpose_error(self):
+        import numpy as np
+        a = np.arange(24).reshape(2, 3, 4)
+        raises(ValueError, a.transpose, 2, 1)
+        raises(ValueError, a.transpose, 1, 0, 3)
+        raises(ValueError, a.transpose, 1, 0, 1)
+        raises(TypeError, a.transpose, 1, 0, '2')
+
     def test_flatiter(self):
         from numpy import array, flatiter, arange, zeros
         a = array([[10, 30], [40, 60]])
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to