Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r70890:d91034c74551 Date: 2014-04-23 15:17 -0400 http://bitbucket.org/pypy/pypy/changeset/d91034c74551/
Log: fix ufunc reduce with single axis tuple (issue1718) diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py --- a/pypy/module/micronumpy/compile.py +++ b/pypy/module/micronumpy/compile.py @@ -136,6 +136,11 @@ def newcomplex(self, r, i): return ComplexObject(r, i) + def getitem(self, obj, index): + assert isinstance(obj, ListObject) + assert isinstance(index, IntObject) + return obj.items[index.intval] + def listview(self, obj, number=-1): assert isinstance(obj, ListObject) if number != -1: diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py --- a/pypy/module/micronumpy/test/test_ndarray.py +++ b/pypy/module/micronumpy/test/test_ndarray.py @@ -1506,6 +1506,9 @@ from numpypy import array, zeros a = array([-1.2, 3.4, 5.7, -3.0, 2.7]) assert a.max() == 5.7 + assert a.max().shape == () + assert a.max(axis=(0,)) == 5.7 + assert a.max(axis=(0,)).shape == () assert a.max(keepdims=True) == 5.7 assert a.max(keepdims=True).shape == (1,) b = array([]) @@ -1521,6 +1524,9 @@ from numpypy import array, zeros a = array([-1.2, 3.4, 5.7, -3.0, 2.7]) assert a.min() == -3.0 + assert a.min().shape == () + assert a.min(axis=(0,)) == -3.0 + assert a.min(axis=(0,)).shape == () assert a.min(keepdims=True) == -3.0 assert a.min(keepdims=True).shape == (1,) b = array([]) diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -772,6 +772,7 @@ a = zeros((2, 2)) + 1 assert (add.reduce(a, axis=1) == [2, 2]).all() + assert (add.reduce(a, axis=(1,)) == [2, 2]).all() exc = raises(ValueError, add.reduce, a, axis=2) assert exc.value[0] == "'axis' entry is out of bounds" diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -178,6 +178,8 @@ if space.is_none(w_axis): axis = maxint else: + if space.isinstance_w(w_axis, space.w_tuple) and space.len_w(w_axis) == 1: + w_axis = space.getitem(w_axis, space.wrap(0)) axis = space.int_w(w_axis) if axis < -shapelen or axis >= shapelen: raise oefmt(space.w_ValueError, "'axis' entry is out of bounds") _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit