Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77357:e7580d87a79a
Date: 2015-05-17 18:47 +0100
http://bitbucket.org/pypy/pypy/changeset/e7580d87a79a/
Log: push more logic inside allowed_types()
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -439,7 +439,7 @@
raise oefmt(space.w_TypeError, 'output must be an array')
w_obj = numpify(space, w_obj)
dtype = w_obj.get_dtype(space)
- calc_dtype, res_dtype, func = self.find_specialization(space, dtype,
out, casting)
+ calc_dtype, dt_out, func = self.find_specialization(space, dtype, out,
casting)
if isinstance(w_obj, W_GenericBox):
if out is None:
return self.call_scalar(space, w_obj, calc_dtype)
@@ -450,7 +450,7 @@
broadcast_down=False)
if out is None:
w_res = W_NDimArray.from_shape(
- space, shape, res_dtype, w_instance=w_obj)
+ space, shape, dt_out, w_instance=w_obj)
else:
w_res = out
w_res = loop.call1(space, shape, func, calc_dtype, w_obj, w_res)
@@ -469,47 +469,30 @@
def find_specialization(self, space, dtype, out, casting):
if dtype.is_flexible():
raise oefmt(space.w_TypeError, 'Not implemented for this type')
- if (self.int_only and not (dtype.is_int() or dtype.is_object()) or
- not self.allow_bool and dtype.is_bool() or
+ if (not self.allow_bool and dtype.is_bool() or
not self.allow_complex and dtype.is_complex()):
raise oefmt(space.w_TypeError,
"ufunc %s not supported for the input type", self.name)
dt_in, dt_out = self._calc_dtype(space, dtype, out, casting)
-
- if out is not None:
- res_dtype = out.get_dtype()
- #if not w_obj.get_dtype().can_cast_to(res_dtype):
- # raise oefmt(space.w_TypeError,
- # "Cannot cast ufunc %s output from dtype('%s') to
dtype('%s') with casting rule 'same_kind'", self.name, w_obj.get_dtype().name,
res_dtype.name)
- elif self.bool_result:
- res_dtype = get_dtype_cache(space).w_booldtype
- else:
- res_dtype = dt_in
- if self.complex_to_float and dt_in.is_complex():
- if dt_in.num == NPY.CFLOAT:
- res_dtype = get_dtype_cache(space).w_float32dtype
- else:
- res_dtype = get_dtype_cache(space).w_float64dtype
- return dt_in, res_dtype, self.func
+ return dt_in, dt_out, self.func
def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'):
use_min_scalar = False
if arg_dtype.is_object():
return arg_dtype, arg_dtype
in_casting = safe_casting_mode(casting)
- for dtype in self.allowed_types(space):
+ for dt_in, dt_out in self.allowed_types(space):
if use_min_scalar:
- if not can_cast_array(space, w_arg, dtype, in_casting):
+ if not can_cast_array(space, w_arg, dt_in, in_casting):
continue
else:
- if not can_cast_type(space, arg_dtype, dtype, in_casting):
+ if not can_cast_type(space, arg_dtype, dt_in, in_casting):
continue
- dt_out = dtype
if out is not None:
res_dtype = out.get_dtype()
if not can_cast_type(space, dt_out, res_dtype, casting):
continue
- return dtype, dt_out
+ return dt_in, dt_out
else:
raise oefmt(space.w_TypeError,
@@ -520,11 +503,24 @@
dtypes = []
cache = get_dtype_cache(space)
if not self.promote_bools and not self.promote_to_float:
- dtypes.append(cache.w_booldtype)
+ dtypes.append((cache.w_booldtype, cache.w_booldtype))
if not self.promote_to_float:
- dtypes.extend(cache.integer_dtypes)
- dtypes.extend(cache.float_dtypes)
- dtypes.extend(cache.complex_dtypes)
+ for dt in cache.integer_dtypes:
+ dtypes.append((dt, dt))
+ if not self.int_only:
+ for dt in cache.float_dtypes:
+ dtypes.append((dt, dt))
+ for dt in cache.complex_dtypes:
+ if self.complex_to_float:
+ if dt.num == NPY.CFLOAT:
+ dt_out = get_dtype_cache(space).w_float32dtype
+ else:
+ dt_out = get_dtype_cache(space).w_float64dtype
+ dtypes.append((dt, dt_out))
+ else:
+ dtypes.append((dt, dt))
+ if self.bool_result:
+ dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
return dtypes
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit