Author: Ronan Lamy <ronan.l...@gmail.com> Branch: fix-result-types Changeset: r77357:e7580d87a79a Date: 2015-05-17 18:47 +0100 http://bitbucket.org/pypy/pypy/changeset/e7580d87a79a/
Log: push more logic inside allowed_types() diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -439,7 +439,7 @@ raise oefmt(space.w_TypeError, 'output must be an array') w_obj = numpify(space, w_obj) dtype = w_obj.get_dtype(space) - calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out, casting) + calc_dtype, dt_out, func = self.find_specialization(space, dtype, out, casting) if isinstance(w_obj, W_GenericBox): if out is None: return self.call_scalar(space, w_obj, calc_dtype) @@ -450,7 +450,7 @@ broadcast_down=False) if out is None: w_res = W_NDimArray.from_shape( - space, shape, res_dtype, w_instance=w_obj) + space, shape, dt_out, w_instance=w_obj) else: w_res = out w_res = loop.call1(space, shape, func, calc_dtype, w_obj, w_res) @@ -469,47 +469,30 @@ def find_specialization(self, space, dtype, out, casting): if dtype.is_flexible(): raise oefmt(space.w_TypeError, 'Not implemented for this type') - if (self.int_only and not (dtype.is_int() or dtype.is_object()) or - not self.allow_bool and dtype.is_bool() or + if (not self.allow_bool and dtype.is_bool() or not self.allow_complex and dtype.is_complex()): raise oefmt(space.w_TypeError, "ufunc %s not supported for the input type", self.name) dt_in, dt_out = self._calc_dtype(space, dtype, out, casting) - - if out is not None: - res_dtype = out.get_dtype() - #if not w_obj.get_dtype().can_cast_to(res_dtype): - # raise oefmt(space.w_TypeError, - # "Cannot cast ufunc %s output from dtype('%s') to dtype('%s') with casting rule 'same_kind'", self.name, w_obj.get_dtype().name, res_dtype.name) - elif self.bool_result: - res_dtype = get_dtype_cache(space).w_booldtype - else: - res_dtype = dt_in - if self.complex_to_float and dt_in.is_complex(): - if dt_in.num == NPY.CFLOAT: - res_dtype = get_dtype_cache(space).w_float32dtype - else: - res_dtype = get_dtype_cache(space).w_float64dtype - return dt_in, res_dtype, self.func + return dt_in, dt_out, self.func def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'): use_min_scalar = False if arg_dtype.is_object(): return arg_dtype, arg_dtype in_casting = safe_casting_mode(casting) - for dtype in self.allowed_types(space): + for dt_in, dt_out in self.allowed_types(space): if use_min_scalar: - if not can_cast_array(space, w_arg, dtype, in_casting): + if not can_cast_array(space, w_arg, dt_in, in_casting): continue else: - if not can_cast_type(space, arg_dtype, dtype, in_casting): + if not can_cast_type(space, arg_dtype, dt_in, in_casting): continue - dt_out = dtype if out is not None: res_dtype = out.get_dtype() if not can_cast_type(space, dt_out, res_dtype, casting): continue - return dtype, dt_out + return dt_in, dt_out else: raise oefmt(space.w_TypeError, @@ -520,11 +503,24 @@ dtypes = [] cache = get_dtype_cache(space) if not self.promote_bools and not self.promote_to_float: - dtypes.append(cache.w_booldtype) + dtypes.append((cache.w_booldtype, cache.w_booldtype)) if not self.promote_to_float: - dtypes.extend(cache.integer_dtypes) - dtypes.extend(cache.float_dtypes) - dtypes.extend(cache.complex_dtypes) + for dt in cache.integer_dtypes: + dtypes.append((dt, dt)) + if not self.int_only: + for dt in cache.float_dtypes: + dtypes.append((dt, dt)) + for dt in cache.complex_dtypes: + if self.complex_to_float: + if dt.num == NPY.CFLOAT: + dt_out = get_dtype_cache(space).w_float32dtype + else: + dt_out = get_dtype_cache(space).w_float64dtype + dtypes.append((dt, dt_out)) + else: + dtypes.append((dt, dt)) + if self.bool_result: + dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes] return dtypes _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit