Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77287:a819ee693791
Date: 2015-05-10 21:16 +0100
http://bitbucket.org/pypy/pypy/changeset/a819ee693791/

Log:    extract method W_Ufunc{1,2}.call_scalar()

diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -366,24 +366,27 @@
                 else:
                     res_dtype = get_dtype_cache(space).w_float64dtype
         if w_obj.is_scalar():
-            w_val = self.func(calc_dtype,
-                              w_obj.get_scalar_value().convert_to(space, calc_dtype))
-            if out is None:
-                if res_dtype.is_object():
-                    w_val = w_obj.get_scalar_value()
-                return w_val
-            w_val = res_dtype.coerce(space, w_val)
-            if out.is_scalar():
-                out.set_scalar_value(w_val)
-            else:
-                out.fill(space, w_val)
-            return out
+            return self.call_scalar(space, w_obj.get_scalar_value(),
+                                    calc_dtype, res_dtype, out)
         assert isinstance(w_obj, W_NDimArray)
         shape = shape_agreement(space, w_obj.get_shape(), out,
                                 broadcast_down=False)
         return loop.call1(space, shape, self.func, calc_dtype, res_dtype,
                           w_obj, out)
 
+    def call_scalar(self, space, w_arg, in_dtype, out_dtype, out):
+        w_val = self.func(in_dtype, w_arg.convert_to(space, in_dtype))
+        if out is None:
+            if out_dtype.is_object():
+                w_val = w_arg
+            return w_val
+        w_val = out_dtype.coerce(space, w_val)
+        if out.is_scalar():
+            out.set_scalar_value(w_val)
+        else:
+            out.fill(space, w_val)
+        return out
+
 
 class W_Ufunc2(W_Ufunc):
     _immutable_fields_ = ["func", "comparison_func", "done_func"]
@@ -486,6 +489,10 @@
         else:
             res_dtype = calc_dtype
         if w_lhs.is_scalar() and w_rhs.is_scalar():
+            return self.call_scalar(space,
+                                    w_lhs.get_scalar_value(),
+                                    w_rhs.get_scalar_value(),
+                                    calc_dtype, res_dtype, out)
             arr = self.func(calc_dtype,
                 w_lhs.get_scalar_value().convert_to(space, calc_dtype),
                 w_rhs.get_scalar_value().convert_to(space, calc_dtype)
@@ -509,6 +516,20 @@
         return loop.call2(space, new_shape, self.func, calc_dtype,
                           res_dtype, w_lhs, w_rhs, out)
 
+    def call_scalar(self, space, w_lhs, w_rhs, in_dtype, out_dtype, out):
+        w_val = self.func(in_dtype,
+                          w_lhs.convert_to(space, in_dtype),
+                          w_rhs.convert_to(space, in_dtype))
+        if out is None:
+            return w_val
+        w_val = out_dtype.coerce(space, w_val)
+        if out.is_scalar():
+            out.set_scalar_value(w_val)
+        else:
+            out.fill(space, w_val)
+        return out
+
+
 
 class W_UfuncGeneric(W_Ufunc):
     '''
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to