Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77599:4743d2084e4e
Date: 2015-05-26 20:25 +0100
http://bitbucket.org/pypy/pypy/changeset/4743d2084e4e/

Log:    precompute W_Ufunc2.allowed_types()

diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -428,7 +428,7 @@
         return casting
 
 class W_Ufunc1(W_Ufunc):
-    _immutable_fields_ = ["func", "bool_result", "dtypes"]
+    _immutable_fields_ = ["func", "bool_result", "dtypes[*]"]
     nin = 1
     nout = 1
     nargs = 2
@@ -514,7 +514,8 @@
 
 
 class W_Ufunc2(W_Ufunc):
-    _immutable_fields_ = ["func", "bool_result", "done_func", "simple_binary"]
+    _immutable_fields_ = ["func", "bool_result", "done_func", "dtypes[*]",
+                          "simple_binary"]
     nin = 2
     nout = 1
     nargs = 3
@@ -665,14 +666,14 @@
         """Find a valid dtype signature of the form xx->x"""
         if dtype.is_object():
             return dtype
-        for dt_in, dt_out in self.allowed_types(space):
+        for dt_in, dt_out in self.dtypes:
             if dtype.can_cast_to(dt_in):
                 if dt_out == dt_in:
                     return dt_in
                 else:
                     dtype = dt_out
                     break
-        for dt_in, dt_out in self.allowed_types(space):
+        for dt_in, dt_out in self.dtypes:
             if dtype.can_cast_to(dt_in) and dt_out == dt_in:
                 return dt_in
         raise ValueError(
@@ -686,7 +687,7 @@
             dtype = get_dtype_cache(space).w_objectdtype
             return dtype, dtype
         in_casting = safe_casting_mode(casting)
-        for dt_in, dt_out in self.allowed_types(space):
+        for dt_in, dt_out in self.dtypes:
             if use_min_scalar:
                 if not can_cast_array(space, w_arg, dt_in, in_casting):
                     continue
@@ -704,30 +705,6 @@
             raise oefmt(space.w_TypeError,
                 "ufunc '%s' not supported for the input types", self.name)
 
-    def allowed_types(self, space):
-        dtypes = []
-        cache = get_dtype_cache(space)
-        if not self.promote_bools and not self.promote_to_float:
-            dtypes.append((cache.w_booldtype, cache.w_booldtype))
-        if not self.promote_to_float:
-            for dt in cache.integer_dtypes:
-                dtypes.append((dt, dt))
-        if not self.int_only:
-            for dt in cache.float_dtypes:
-                dtypes.append((dt, dt))
-            for dt in cache.complex_dtypes:
-                if self.complex_to_float:
-                    if dt.num == NPY.CFLOAT:
-                        dt_out = get_dtype_cache(space).w_float32dtype
-                    else:
-                        dt_out = get_dtype_cache(space).w_float64dtype
-                    dtypes.append((dt, dt_out))
-                else:
-                    dtypes.append((dt, dt))
-        if self.bool_result:
-            dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
-        return dtypes
-
 
 
 class W_UfuncGeneric(W_Ufunc):
@@ -1310,7 +1287,7 @@
         if nin == 1:
             ufunc = unary_ufunc(space, func, ufunc_name, **extra_kwargs)
         elif nin == 2:
-            ufunc = W_Ufunc2(func, ufunc_name, **extra_kwargs)
+            ufunc = binary_ufunc(space, func, ufunc_name, **extra_kwargs)
         setattr(self, ufunc_name, ufunc)
 
 def unary_ufunc(space, func, ufunc_name, **kwargs):
@@ -1342,6 +1319,35 @@
         dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
     return dtypes
 
+def binary_ufunc(space, func, ufunc_name, **kwargs):
+    ufunc = W_Ufunc2(func, ufunc_name, **kwargs)
+    ufunc.dtypes = _ufunc2_dtypes(ufunc, space)
+    return ufunc
+
+def _ufunc2_dtypes(ufunc, space):
+    dtypes = []
+    cache = get_dtype_cache(space)
+    if not ufunc.promote_bools and not ufunc.promote_to_float:
+        dtypes.append((cache.w_booldtype, cache.w_booldtype))
+    if not ufunc.promote_to_float:
+        for dt in cache.integer_dtypes:
+            dtypes.append((dt, dt))
+    if not ufunc.int_only:
+        for dt in cache.float_dtypes:
+            dtypes.append((dt, dt))
+        for dt in cache.complex_dtypes:
+            if ufunc.complex_to_float:
+                if dt.num == NPY.CFLOAT:
+                    dt_out = get_dtype_cache(space).w_float32dtype
+                else:
+                    dt_out = get_dtype_cache(space).w_float64dtype
+                dtypes.append((dt, dt_out))
+            else:
+                dtypes.append((dt, dt))
+    if ufunc.bool_result:
+        dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
+    return dtypes
+
 
 def get(space):
     return space.fromcache(UfuncState)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to