Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: use_min_scalar
Changeset: r77748:5b71a45fc55b
Date: 2015-06-01 20:22 +0100
http://bitbucket.org/pypy/pypy/changeset/5b71a45fc55b/

Log:    correct handling of scalars for simple binary ufuncs

diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -640,7 +640,14 @@
             return w_val.w_obj
         return w_val
 
-    def _find_specialization(self, space, l_dtype, r_dtype, out, casting):
+    def _find_specialization(self, space, l_dtype, r_dtype, out, casting,
+                             w_arg1, w_arg2):
+        if (self.are_common_types(l_dtype, r_dtype) and
+                w_arg1 is not None and w_arg2 is not None):
+            if not w_arg1.is_scalar() and w_arg2.is_scalar():
+                r_dtype = l_dtype
+            elif w_arg1.is_scalar() and not w_arg2.is_scalar():
+                l_dtype = r_dtype
         if (not self.allow_bool and (l_dtype.is_bool() or
                                          r_dtype.is_bool()) or
                 not self.allow_complex and (l_dtype.is_complex() or
@@ -657,17 +664,17 @@
 
     def find_specialization(self, space, l_dtype, r_dtype, out, casting,
                             w_arg1=None, w_arg2=None):
-        if (self.are_common_types(l_dtype, r_dtype) and
-                w_arg1 is not None and w_arg2 is not None):
-            if not w_arg1.is_scalar() and w_arg2.is_scalar():
-                r_dtype = l_dtype
-            elif w_arg1.is_scalar() and not w_arg2.is_scalar():
-                l_dtype = r_dtype
         if self.simple_binary:
             if out is None and not (l_dtype.is_object() or r_dtype.is_object()):
-                dtype = promote_types(space, l_dtype, r_dtype)
+                if w_arg1 is not None and w_arg2 is not None:
+                    w_arg1 = convert_to_array(space, w_arg1)
+                    w_arg2 = convert_to_array(space, w_arg2)
+                    dtype = find_result_type(space, [w_arg1, w_arg2], [])
+                else:
+                    dtype = promote_types(space, l_dtype, r_dtype)
                 return dtype, dtype, self.func
-        return self._find_specialization(space, l_dtype, r_dtype, out, casting)
+        return self._find_specialization(
+            space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2)
 
     def find_binop_type(self, space, dtype):
         """Find a valid dtype signature of the form xx->x"""
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to