Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77345:f0619db6ce92
Date: 2015-05-16 19:21 +0100
http://bitbucket.org/pypy/pypy/changeset/f0619db6ce92/

Log:    Create W_Ufunc2.find_specialization()

diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -562,10 +562,14 @@
     def call(self, space, args_w, sig, casting, extobj):
         w_obj = args_w[0]
         if len(args_w) > 2:
-            [w_lhs, w_rhs, w_out] = args_w
+            [w_lhs, w_rhs, out] = args_w
+            if space.is_none(out):
+                out = None
+            elif not isinstance(out, W_NDimArray):
+                raise oefmt(space.w_TypeError, 'output must be an array')
         else:
             [w_lhs, w_rhs] = args_w
-            w_out = None
+            out = None
         if not isinstance(w_rhs, W_NDimArray):
             # numpy implementation detail, useful for things like numpy.Polynomial
             # FAIL with NotImplemented if the other object has
@@ -585,12 +589,12 @@
                 self.bool_result:
             pass
         elif (w_ldtype.is_str()) and \
-                self.bool_result and w_out is None:
+                self.bool_result and out is None:
             if self.name in ('equal', 'less_equal', 'less'):
                return space.wrap(False)
             return space.wrap(True)
         elif (w_rdtype.is_str()) and \
-                self.bool_result and w_out is None:
+                self.bool_result and out is None:
             if self.name in ('not_equal','less', 'less_equal'):
                return space.wrap(True)
             return space.wrap(False)
@@ -613,30 +617,7 @@
                 w_rdtype = w_ldtype
             elif w_lhs.is_scalar() and not w_rhs.is_scalar():
                 w_ldtype = w_rdtype
-        calc_dtype = find_binop_result_dtype(space,
-            w_ldtype, w_rdtype,
-            promote_to_float=self.promote_to_float,
-            promote_bools=self.promote_bools)
-        if (self.int_only and (not (w_ldtype.is_int() or w_ldtype.is_object()) or
-                               not (w_rdtype.is_int() or w_rdtype.is_object()) or
-                               not (calc_dtype.is_int() or calc_dtype.is_object())) or
-                not self.allow_bool and (w_ldtype.is_bool() or
-                                         w_rdtype.is_bool()) or
-                not self.allow_complex and (w_ldtype.is_complex() or
-                                            w_rdtype.is_complex())):
-            raise oefmt(space.w_TypeError,
-                "ufunc '%s' not supported for the input types", self.name)
-        if space.is_none(w_out):
-            out = None
-        elif not isinstance(w_out, W_NDimArray):
-            raise oefmt(space.w_TypeError, 'output must be an array')
-        else:
-            out = w_out
-            calc_dtype = out.get_dtype()
-        if self.bool_result:
-            res_dtype = get_dtype_cache(space).w_booldtype
-        else:
-            res_dtype = calc_dtype
+        calc_dtype, res_dtype, func = self.find_specialization(space, w_ldtype, w_rdtype, out, casting)
         if (isinstance(w_lhs, W_GenericBox) and
                 isinstance(w_rhs, W_GenericBox) and out is None):
             return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
@@ -670,6 +651,28 @@
             return w_val.w_obj
         return w_val
 
+    def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+        calc_dtype = find_binop_result_dtype(space,
+            l_dtype, r_dtype,
+            promote_to_float=self.promote_to_float,
+            promote_bools=self.promote_bools)
+        if (self.int_only and (not (l_dtype.is_int() or l_dtype.is_object()) or
+                               not (r_dtype.is_int() or r_dtype.is_object()) or
+                               not (calc_dtype.is_int() or calc_dtype.is_object())) or
+                not self.allow_bool and (l_dtype.is_bool() or
+                                         r_dtype.is_bool()) or
+                not self.allow_complex and (l_dtype.is_complex() or
+                                            r_dtype.is_complex())):
+            raise oefmt(space.w_TypeError,
+                "ufunc '%s' not supported for the input types", self.name)
+        if out is not None:
+            calc_dtype = out.get_dtype()
+        if self.bool_result:
+            res_dtype = get_dtype_cache(space).w_booldtype
+        else:
+            res_dtype = calc_dtype
+        return calc_dtype, res_dtype, self.func
+
 
 
 class W_UfuncGeneric(W_Ufunc):
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to