Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: fix-result-types
Changeset: r77357:e7580d87a79a
Date: 2015-05-17 18:47 +0100
http://bitbucket.org/pypy/pypy/changeset/e7580d87a79a/

Log:    push more logic inside allowed_types()

diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -439,7 +439,7 @@
                 raise oefmt(space.w_TypeError, 'output must be an array')
         w_obj = numpify(space, w_obj)
         dtype = w_obj.get_dtype(space)
-        calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out, casting)
+        calc_dtype, dt_out, func = self.find_specialization(space, dtype, out, casting)
         if isinstance(w_obj, W_GenericBox):
             if out is None:
                 return self.call_scalar(space, w_obj, calc_dtype)
@@ -450,7 +450,7 @@
                                 broadcast_down=False)
         if out is None:
             w_res = W_NDimArray.from_shape(
-                space, shape, res_dtype, w_instance=w_obj)
+                space, shape, dt_out, w_instance=w_obj)
         else:
             w_res = out
         w_res = loop.call1(space, shape, func, calc_dtype, w_obj, w_res)
@@ -469,47 +469,30 @@
     def find_specialization(self, space, dtype, out, casting):
         if dtype.is_flexible():
             raise oefmt(space.w_TypeError, 'Not implemented for this type')
-        if (self.int_only and not (dtype.is_int() or dtype.is_object()) or
-                not self.allow_bool and dtype.is_bool() or
+        if (not self.allow_bool and dtype.is_bool() or
                 not self.allow_complex and dtype.is_complex()):
             raise oefmt(space.w_TypeError,
                 "ufunc %s not supported for the input type", self.name)
         dt_in, dt_out = self._calc_dtype(space, dtype, out, casting)
-
-        if out is not None:
-            res_dtype = out.get_dtype()
-            #if not w_obj.get_dtype().can_cast_to(res_dtype):
-            #    raise oefmt(space.w_TypeError,
-            #        "Cannot cast ufunc %s output from dtype('%s') to dtype('%s') with casting rule 'same_kind'", self.name, w_obj.get_dtype().name, res_dtype.name)
-        elif self.bool_result:
-            res_dtype = get_dtype_cache(space).w_booldtype
-        else:
-            res_dtype = dt_in
-            if self.complex_to_float and dt_in.is_complex():
-                if dt_in.num == NPY.CFLOAT:
-                    res_dtype = get_dtype_cache(space).w_float32dtype
-                else:
-                    res_dtype = get_dtype_cache(space).w_float64dtype
-        return dt_in, res_dtype, self.func
+        return dt_in, dt_out, self.func
 
     def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'):
         use_min_scalar = False
         if arg_dtype.is_object():
             return arg_dtype, arg_dtype
         in_casting = safe_casting_mode(casting)
-        for dtype in self.allowed_types(space):
+        for dt_in, dt_out in self.allowed_types(space):
             if use_min_scalar:
-                if not can_cast_array(space, w_arg, dtype, in_casting):
+                if not can_cast_array(space, w_arg, dt_in, in_casting):
                     continue
             else:
-                if not can_cast_type(space, arg_dtype, dtype, in_casting):
+                if not can_cast_type(space, arg_dtype, dt_in, in_casting):
                     continue
-            dt_out = dtype
             if out is not None:
                 res_dtype = out.get_dtype()
                 if not can_cast_type(space, dt_out, res_dtype, casting):
                     continue
-            return dtype, dt_out
+            return dt_in, dt_out
 
         else:
             raise oefmt(space.w_TypeError,
@@ -520,11 +503,24 @@
         dtypes = []
         cache = get_dtype_cache(space)
         if not self.promote_bools and not self.promote_to_float:
-            dtypes.append(cache.w_booldtype)
+            dtypes.append((cache.w_booldtype, cache.w_booldtype))
         if not self.promote_to_float:
-            dtypes.extend(cache.integer_dtypes)
-        dtypes.extend(cache.float_dtypes)
-        dtypes.extend(cache.complex_dtypes)
+            for dt in cache.integer_dtypes:
+                dtypes.append((dt, dt))
+        if not self.int_only:
+            for dt in cache.float_dtypes:
+                dtypes.append((dt, dt))
+            for dt in cache.complex_dtypes:
+                if self.complex_to_float:
+                    if dt.num == NPY.CFLOAT:
+                        dt_out = get_dtype_cache(space).w_float32dtype
+                    else:
+                        dt_out = get_dtype_cache(space).w_float64dtype
+                    dtypes.append((dt, dt_out))
+                else:
+                    dtypes.append((dt, dt))
+        if self.bool_result:
+            dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
         return dtypes
 
 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to