Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77344:4e2f77ed8e96
Date: 2015-05-16 16:57 +0100
http://bitbucket.org/pypy/pypy/changeset/4e2f77ed8e96/

Log:    Let W_Ufunc2 handle scalars in the same way as W_Ufunc1

diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -425,6 +425,11 @@
         c = add(a, b)
         for i in range(3):
             assert c[i] == a[i] + b[i]
+        class Obj(object):
+            def __add__(self, other):
+                return 'add'
+        x = Obj()
+        assert type(add(x, 0)) is str
 
     def test_divide(self):
         from numpy import array, divide
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -11,7 +11,7 @@
 from rpython.rtyper.lltypesystem import rffi, lltype
 from rpython.rlib.objectmodel import keepalive_until_here
 
-from pypy.module.micronumpy import boxes, loop, constants as NPY
+from pypy.module.micronumpy import loop, constants as NPY
 from pypy.module.micronumpy.descriptor import get_dtype_cache, decode_w_dtype
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
 from pypy.module.micronumpy.ctors import numpify
@@ -442,7 +442,7 @@
         calc_dtype, res_dtype, func = self.find_specialization(space, dtype, 
out, casting)
         if isinstance(w_obj, W_GenericBox):
             if out is None:
-                return self.call_scalar(space, w_obj, calc_dtype, res_dtype)
+                return self.call_scalar(space, w_obj, calc_dtype)
             else:
                 w_obj = W_NDimArray.from_scalar(space, w_obj)
         assert isinstance(w_obj, W_NDimArray)
@@ -460,7 +460,7 @@
             w_res = space.call_method(w_obj, '__array_wrap__', w_res)
         return w_res
 
-    def call_scalar(self, space, w_arg, in_dtype, out_dtype):
+    def call_scalar(self, space, w_arg, in_dtype):
         w_val = self.func(in_dtype, w_arg.convert_to(space, in_dtype))
         if isinstance(w_val, W_ObjectBox):
             return w_val.w_obj
@@ -637,43 +637,38 @@
             res_dtype = get_dtype_cache(space).w_booldtype
         else:
             res_dtype = calc_dtype
-        if w_lhs.is_scalar() and w_rhs.is_scalar():
-            return self.call_scalar(space,
-                                    w_lhs.get_scalar_value(),
-                                    w_rhs.get_scalar_value(),
-                                    calc_dtype, res_dtype, out)
-        if isinstance(w_lhs, boxes.W_GenericBox):
+        if (isinstance(w_lhs, W_GenericBox) and
+                isinstance(w_rhs, W_GenericBox) and out is None):
+            return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
+        if isinstance(w_lhs, W_GenericBox):
             w_lhs = W_NDimArray.from_scalar(space, w_lhs)
         assert isinstance(w_lhs, W_NDimArray)
-        if isinstance(w_rhs, boxes.W_GenericBox):
+        if isinstance(w_rhs, W_GenericBox):
             w_rhs = W_NDimArray.from_scalar(space, w_rhs)
         assert isinstance(w_rhs, W_NDimArray)
         new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
         new_shape = shape_agreement(space, new_shape, out, 
broadcast_down=False)
         w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs)
         if out is None:
-            w_ret = W_NDimArray.from_shape(space, new_shape, res_dtype,
+            w_res = W_NDimArray.from_shape(space, new_shape, res_dtype,
                                            w_instance=out_subtype)
         else:
-            w_ret = out
-        w_ret = loop.call2(space, new_shape, self.func, calc_dtype,
-                           w_lhs, w_rhs, w_ret)
+            w_res = out
+        w_res = loop.call2(space, new_shape, self.func, calc_dtype,
+                           w_lhs, w_rhs, w_res)
         if out is None:
-            w_ret = space.call_method(w_highpriority, '__array_wrap__', w_ret)
-        return w_ret
+            if w_res.is_scalar():
+                return w_res.get_scalar_value()
+            w_res = space.call_method(w_highpriority, '__array_wrap__', w_res)
+        return w_res
 
-    def call_scalar(self, space, w_lhs, w_rhs, in_dtype, out_dtype, out):
+    def call_scalar(self, space, w_lhs, w_rhs, in_dtype):
         w_val = self.func(in_dtype,
                           w_lhs.convert_to(space, in_dtype),
                           w_rhs.convert_to(space, in_dtype))
-        if out is None:
-            return w_val
-        w_val = out_dtype.coerce(space, w_val)
-        if out.is_scalar():
-            out.set_scalar_value(w_val)
-        else:
-            out.fill(space, w_val)
-        return out
+        if isinstance(w_val, W_ObjectBox):
+            return w_val.w_obj
+        return w_val
 
 
 
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to