Author: Matti Picus <matti.pi...@gmail.com>
Branch: 
Changeset: r87828:d86329133431
Date: 2016-10-16 14:30 +0300
http://bitbucket.org/pypy/pypy/changeset/d86329133431/

Log:    test, fix for returning scalar from reduce function, fixes
        pypy/numpy #57

diff --git a/pypy/module/micronumpy/test/dummy_module.py b/pypy/module/micronumpy/test/dummy_module.py
--- a/pypy/module/micronumpy/test/dummy_module.py
+++ b/pypy/module/micronumpy/test/dummy_module.py
@@ -27,7 +27,8 @@
 globals()['uint'] = dtype('uint').type
 
 types = ['Generic', 'Number', 'Integer', 'SignedInteger', 'UnsignedInteger',
-         'Inexact', 'Floating', 'ComplexFloating', 'Flexible', 'Character']
+         'Inexact', 'Floating', 'ComplexFloating', 'Flexible', 'Character',
+        ]
 for t in types:
     globals()[t.lower()] = typeinfo[t]
 
@@ -40,4 +41,4 @@
     return a
 
 def isscalar(a):
-    return type(a) in [typeinfo[t] for t in types]
+    return any([isinstance(a, typeinfo[t]) for t in types])
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1486,7 +1486,7 @@
         assert d[1] == 12
 
     def test_sum(self):
-        from numpy import array, zeros, float16, complex64, str_
+        from numpy import array, zeros, float16, complex64, str_, isscalar, add
         a = array(range(5))
         assert a.sum() == 10
         assert a[:4].sum() == 6
@@ -1515,6 +1515,13 @@
 
         assert list(zeros((0, 2)).sum(axis=1)) == []
 
+        a = array([1, 2, 3, 4]).sum()
+        s = isscalar(a)
+        assert s is True
+        a = add.reduce([1.0, 2, 3, 4])
+        s = isscalar(a)
+        assert s is True,'%r is not a scalar' % type(a)
+
     def test_reduce_nd(self):
         from numpy import arange, array
         a = arange(15).reshape(5, 3)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -288,10 +288,8 @@
 
         _, dtype, _ = self.find_specialization(space, dtype, dtype, out,
                                                    casting='unsafe')
-        call__array_wrap__ = True
         if shapelen == len(axes):
             if out:
-                call__array_wrap__ = False
                 if out.ndims() > 0:
                     raise oefmt(space.w_ValueError,
                                 "output parameter for reduction operation %s has "
@@ -302,15 +300,20 @@
             if out:
                 out.set_scalar_value(res)
                 return out
+            w_NDimArray = space.gettypefor(W_NDimArray)
+            call__array_wrap__ = False
             if keepdims:
                 shape = [1] * len(obj_shape)
                 out = W_NDimArray.from_shape(space, shape, dtype, w_instance=obj)
                 out.implementation.setitem(0, res)
+                call__array_wrap__ = True
                 res = out
-            elif not space.is_w(space.type(w_obj), space.gettypefor(W_NDimArray)):
+            elif (space.issubtype_w(space.type(w_obj), w_NDimArray) and 
+                  not space.is_w(space.type(w_obj), w_NDimArray)):
                 # subtypes return a ndarray subtype, not a scalar
                 out = W_NDimArray.from_shape(space, [1], dtype, w_instance=obj)
                 out.implementation.setitem(0, res)
+                call__array_wrap__ = True
                 res = out
             if call__array_wrap__:
                 res = space.call_method(obj, '__array_wrap__', res, space.w_None)
@@ -359,8 +362,7 @@
                 return out
             loop.reduce(
                 space, self.func, obj, axis_flags, dtype, out, self.identity)
-            if call__array_wrap__:
-                out = space.call_method(obj, '__array_wrap__', out, space.w_None)
+            out = space.call_method(obj, '__array_wrap__', out, space.w_None)
             return out
 
     def descr_outer(self, space, args_w):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to