Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r68011:f85795564612
Date: 2013-11-13 21:33 -0500
http://bitbucket.org/pypy/pypy/changeset/f85795564612/

Log:    support axis argument for array.squeeze

diff --git a/pypy/module/micronumpy/conversion_utils.py b/pypy/module/micronumpy/conversion_utils.py
--- a/pypy/module/micronumpy/conversion_utils.py
+++ b/pypy/module/micronumpy/conversion_utils.py
@@ -40,3 +40,24 @@
         else:
             raise OperationError(space.w_TypeError, space.wrap(
                 "order not understood"))
+
+def multi_axis_converter(space, w_axis, ndim):
+    if space.is_none(w_axis):
+        return [True] * ndim
+    out = [False] * ndim
+    if not space.isinstance_w(w_axis, space.w_tuple):
+        w_axis = space.newtuple([w_axis])
+    for w_item in space.fixedview(w_axis):
+        item = space.int_w(w_item)
+        axis = item
+        if axis < 0:
+            axis += ndim
+        if axis < 0 or axis >= ndim:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "'axis' entry %d is out of bounds [-%d, %d)" %
+                (item, ndim, ndim)))
+        if out[axis]:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "duplicate value in 'axis'"))
+        out[axis] = True
+    return out
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -19,7 +19,7 @@
 from rpython.rlib import jit
 from rpython.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.arrayimpl.base import BaseArrayImplementation
-from pypy.module.micronumpy.conversion_utils import order_converter
+from pypy.module.micronumpy.conversion_utils import order_converter, multi_axis_converter
 from pypy.module.micronumpy.constants import *
 
 def _find_shape(space, w_size, dtype):
@@ -692,11 +692,20 @@
         return self.implementation.sort(space, w_axis, w_order)
 
     def descr_squeeze(self, space, w_axis=None):
+        cur_shape = self.get_shape()
         if not space.is_none(w_axis):
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                "axis unsupported for squeeze"))
-        cur_shape = self.get_shape()
-        new_shape = [s for s in cur_shape if s != 1]
+            axes = multi_axis_converter(space, w_axis, len(cur_shape))
+            new_shape = []
+            for i in range(len(cur_shape)):
+                if axes[i]:
+                    if cur_shape[i] != 1:
+                        raise OperationError(space.w_ValueError, space.wrap(
+                            "cannot select an axis to squeeze out " \
+                            "which has size greater than one"))
+                else:
+                    new_shape.append(cur_shape[i])
+        else:
+            new_shape = [s for s in cur_shape if s != 1]
         if len(cur_shape) == len(new_shape):
             return self
         return wrap_impl(space, space.type(self), self,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1746,6 +1746,18 @@
         assert (b == a).all()
         b[1] = -1
         assert a[0][1] == -1
+        a = np.arange(9).reshape((3, 1, 3, 1))
+        b = a.squeeze(1)
+        assert b.shape == (3, 3, 1)
+        b = a.squeeze((1,))
+        assert b.shape == (3, 3, 1)
+        b = a.squeeze((1, -1))
+        assert b.shape == (3, 3)
+        exc = raises(ValueError, a.squeeze, 5)
+        assert exc.value.message == "'axis' entry 5 is out of bounds [-4, 4)"
+        exc = raises(ValueError, a.squeeze, 0)
+        assert exc.value.message == "cannot select an axis to squeeze out " \
+                                    "which has size greater than one"
 
     def test_swapaxes(self):
         from numpypy import array
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to