Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77599:4743d2084e4e
Date: 2015-05-26 20:25 +0100
http://bitbucket.org/pypy/pypy/changeset/4743d2084e4e/
Log: precompute W_Ufunc2.allowed_types()
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -428,7 +428,7 @@
return casting
class W_Ufunc1(W_Ufunc):
- _immutable_fields_ = ["func", "bool_result", "dtypes"]
+ _immutable_fields_ = ["func", "bool_result", "dtypes[*]"]
nin = 1
nout = 1
nargs = 2
@@ -514,7 +514,8 @@
class W_Ufunc2(W_Ufunc):
- _immutable_fields_ = ["func", "bool_result", "done_func", "simple_binary"]
+ _immutable_fields_ = ["func", "bool_result", "done_func", "dtypes[*]",
+ "simple_binary"]
nin = 2
nout = 1
nargs = 3
@@ -665,14 +666,14 @@
"""Find a valid dtype signature of the form xx->x"""
if dtype.is_object():
return dtype
- for dt_in, dt_out in self.allowed_types(space):
+ for dt_in, dt_out in self.dtypes:
if dtype.can_cast_to(dt_in):
if dt_out == dt_in:
return dt_in
else:
dtype = dt_out
break
- for dt_in, dt_out in self.allowed_types(space):
+ for dt_in, dt_out in self.dtypes:
if dtype.can_cast_to(dt_in) and dt_out == dt_in:
return dt_in
raise ValueError(
@@ -686,7 +687,7 @@
dtype = get_dtype_cache(space).w_objectdtype
return dtype, dtype
in_casting = safe_casting_mode(casting)
- for dt_in, dt_out in self.allowed_types(space):
+ for dt_in, dt_out in self.dtypes:
if use_min_scalar:
if not can_cast_array(space, w_arg, dt_in, in_casting):
continue
@@ -704,30 +705,6 @@
raise oefmt(space.w_TypeError,
"ufunc '%s' not supported for the input types", self.name)
- def allowed_types(self, space):
- dtypes = []
- cache = get_dtype_cache(space)
- if not self.promote_bools and not self.promote_to_float:
- dtypes.append((cache.w_booldtype, cache.w_booldtype))
- if not self.promote_to_float:
- for dt in cache.integer_dtypes:
- dtypes.append((dt, dt))
- if not self.int_only:
- for dt in cache.float_dtypes:
- dtypes.append((dt, dt))
- for dt in cache.complex_dtypes:
- if self.complex_to_float:
- if dt.num == NPY.CFLOAT:
- dt_out = get_dtype_cache(space).w_float32dtype
- else:
- dt_out = get_dtype_cache(space).w_float64dtype
- dtypes.append((dt, dt_out))
- else:
- dtypes.append((dt, dt))
- if self.bool_result:
- dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
- return dtypes
-
class W_UfuncGeneric(W_Ufunc):
@@ -1310,7 +1287,7 @@
if nin == 1:
ufunc = unary_ufunc(space, func, ufunc_name, **extra_kwargs)
elif nin == 2:
- ufunc = W_Ufunc2(func, ufunc_name, **extra_kwargs)
+ ufunc = binary_ufunc(space, func, ufunc_name, **extra_kwargs)
setattr(self, ufunc_name, ufunc)
def unary_ufunc(space, func, ufunc_name, **kwargs):
@@ -1342,6 +1319,35 @@
dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
return dtypes
+def binary_ufunc(space, func, ufunc_name, **kwargs):
+ ufunc = W_Ufunc2(func, ufunc_name, **kwargs)
+ ufunc.dtypes = _ufunc2_dtypes(ufunc, space)
+ return ufunc
+
+def _ufunc2_dtypes(ufunc, space):
+ dtypes = []
+ cache = get_dtype_cache(space)
+ if not ufunc.promote_bools and not ufunc.promote_to_float:
+ dtypes.append((cache.w_booldtype, cache.w_booldtype))
+ if not ufunc.promote_to_float:
+ for dt in cache.integer_dtypes:
+ dtypes.append((dt, dt))
+ if not ufunc.int_only:
+ for dt in cache.float_dtypes:
+ dtypes.append((dt, dt))
+ for dt in cache.complex_dtypes:
+ if ufunc.complex_to_float:
+ if dt.num == NPY.CFLOAT:
+ dt_out = get_dtype_cache(space).w_float32dtype
+ else:
+ dt_out = get_dtype_cache(space).w_float64dtype
+ dtypes.append((dt, dt_out))
+ else:
+ dtypes.append((dt, dt))
+ if ufunc.bool_result:
+ dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
+ return dtypes
+
def get(space):
return space.fromcache(UfuncState)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit