Author: Brian Kearns <bdkea...@gmail.com>
Branch: 
Changeset: r70890:d91034c74551
Date: 2014-04-23 15:17 -0400
http://bitbucket.org/pypy/pypy/changeset/d91034c74551/

Log:    fix ufunc reduce with single axis tuple (issue1718)

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -136,6 +136,11 @@
     def newcomplex(self, r, i):
         return ComplexObject(r, i)
 
+    def getitem(self, obj, index):
+        assert isinstance(obj, ListObject)
+        assert isinstance(index, IntObject)
+        return obj.items[index.intval]
+
     def listview(self, obj, number=-1):
         assert isinstance(obj, ListObject)
         if number != -1:
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1506,6 +1506,9 @@
         from numpypy import array, zeros
         a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
         assert a.max() == 5.7
+        assert a.max().shape == ()
+        assert a.max(axis=(0,)) == 5.7
+        assert a.max(axis=(0,)).shape == ()
         assert a.max(keepdims=True) == 5.7
         assert a.max(keepdims=True).shape == (1,)
         b = array([])
@@ -1521,6 +1524,9 @@
         from numpypy import array, zeros
         a = array([-1.2, 3.4, 5.7, -3.0, 2.7])
         assert a.min() == -3.0
+        assert a.min().shape == ()
+        assert a.min(axis=(0,)) == -3.0
+        assert a.min(axis=(0,)).shape == ()
         assert a.min(keepdims=True) == -3.0
         assert a.min(keepdims=True).shape == (1,)
         b = array([])
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -772,6 +772,7 @@
 
         a = zeros((2, 2)) + 1
         assert (add.reduce(a, axis=1) == [2, 2]).all()
+        assert (add.reduce(a, axis=(1,)) == [2, 2]).all()
         exc = raises(ValueError, add.reduce, a, axis=2)
         assert exc.value[0] == "'axis' entry is out of bounds"
 
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -178,6 +178,8 @@
         if space.is_none(w_axis):
             axis = maxint
         else:
+            if space.isinstance_w(w_axis, space.w_tuple) and space.len_w(w_axis) == 1:
+                w_axis = space.getitem(w_axis, space.wrap(0))
             axis = space.int_w(w_axis)
             if axis < -shapelen or axis >= shapelen:
                 raise oefmt(space.w_ValueError, "'axis' entry is out of bounds")
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to