Author: Brian Kearns <[email protected]>
Branch:
Changeset: r68011:f85795564612
Date: 2013-11-13 21:33 -0500
http://bitbucket.org/pypy/pypy/changeset/f85795564612/
Log: support axis argument for array.squeeze
diff --git a/pypy/module/micronumpy/conversion_utils.py b/pypy/module/micronumpy/conversion_utils.py
--- a/pypy/module/micronumpy/conversion_utils.py
+++ b/pypy/module/micronumpy/conversion_utils.py
@@ -40,3 +40,24 @@
else:
raise OperationError(space.w_TypeError, space.wrap(
"order not understood"))
+
def multi_axis_converter(space, w_axis, ndim):
    """Turn an 'axis' argument into a per-dimension selection mask.

    w_axis may be None (meaning: select every axis), a single integer, or
    a tuple of integers.  Negative entries are counted from the end by
    adding ndim.

    Returns a list of ndim booleans; entry i is True when axis i was
    selected.

    Raises OperationError(ValueError) when an entry falls outside
    [-ndim, ndim) or when the same axis is named more than once.
    """
    if space.is_none(w_axis):
        # No axis given: every dimension is selected.
        return [True] * ndim
    if space.isinstance_w(w_axis, space.w_tuple):
        w_entries = w_axis
    else:
        # A lone scalar is treated exactly like a 1-tuple.
        w_entries = space.newtuple([w_axis])
    selected = [False] * ndim
    for w_entry in space.fixedview(w_entries):
        requested = space.int_w(w_entry)
        # Keep the caller's original value for the error message.
        axis = requested + ndim if requested < 0 else requested
        if axis < 0 or ndim <= axis:
            raise OperationError(space.w_ValueError, space.wrap(
                "'axis' entry %d is out of bounds [-%d, %d)" %
                (requested, ndim, ndim)))
        if selected[axis]:
            raise OperationError(space.w_ValueError, space.wrap(
                "duplicate value in 'axis'"))
        selected[axis] = True
    return selected
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -19,7 +19,7 @@
from rpython.rlib import jit
from rpython.rlib.rstring import StringBuilder
from pypy.module.micronumpy.arrayimpl.base import BaseArrayImplementation
-from pypy.module.micronumpy.conversion_utils import order_converter
+from pypy.module.micronumpy.conversion_utils import order_converter,
multi_axis_converter
from pypy.module.micronumpy.constants import *
def _find_shape(space, w_size, dtype):
@@ -692,11 +692,20 @@
return self.implementation.sort(space, w_axis, w_order)
def descr_squeeze(self, space, w_axis=None):
+ cur_shape = self.get_shape()
if not space.is_none(w_axis):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- "axis unsupported for squeeze"))
- cur_shape = self.get_shape()
- new_shape = [s for s in cur_shape if s != 1]
+ axes = multi_axis_converter(space, w_axis, len(cur_shape))
+ new_shape = []
+ for i in range(len(cur_shape)):
+ if axes[i]:
+ if cur_shape[i] != 1:
+ raise OperationError(space.w_ValueError, space.wrap(
+ "cannot select an axis to squeeze out " \
+ "which has size greater than one"))
+ else:
+ new_shape.append(cur_shape[i])
+ else:
+ new_shape = [s for s in cur_shape if s != 1]
if len(cur_shape) == len(new_shape):
return self
return wrap_impl(space, space.type(self), self,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1746,6 +1746,18 @@
assert (b == a).all()
b[1] = -1
assert a[0][1] == -1
+ a = np.arange(9).reshape((3, 1, 3, 1))
+ b = a.squeeze(1)
+ assert b.shape == (3, 3, 1)
+ b = a.squeeze((1,))
+ assert b.shape == (3, 3, 1)
+ b = a.squeeze((1, -1))
+ assert b.shape == (3, 3)
+ exc = raises(ValueError, a.squeeze, 5)
+ assert exc.value.message == "'axis' entry 5 is out of bounds [-4, 4)"
+ exc = raises(ValueError, a.squeeze, 0)
+ assert exc.value.message == "cannot select an axis to squeeze out " \
+ "which has size greater than one"
def test_swapaxes(self):
from numpypy import array
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit