Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77344:4e2f77ed8e96
Date: 2015-05-16 16:57 +0100
http://bitbucket.org/pypy/pypy/changeset/4e2f77ed8e96/
Log: Let W_Ufunc2 handle scalars in the same way as W_Ufunc1
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -425,6 +425,11 @@
c = add(a, b)
for i in range(3):
assert c[i] == a[i] + b[i]
+ class Obj(object):
+ def __add__(self, other):
+ return 'add'
+ x = Obj()
+ assert type(add(x, 0)) is str
def test_divide(self):
from numpy import array, divide
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -11,7 +11,7 @@
from rpython.rtyper.lltypesystem import rffi, lltype
from rpython.rlib.objectmodel import keepalive_until_here
-from pypy.module.micronumpy import boxes, loop, constants as NPY
+from pypy.module.micronumpy import loop, constants as NPY
from pypy.module.micronumpy.descriptor import get_dtype_cache, decode_w_dtype
from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
from pypy.module.micronumpy.ctors import numpify
@@ -442,7 +442,7 @@
calc_dtype, res_dtype, func = self.find_specialization(space, dtype,
out, casting)
if isinstance(w_obj, W_GenericBox):
if out is None:
- return self.call_scalar(space, w_obj, calc_dtype, res_dtype)
+ return self.call_scalar(space, w_obj, calc_dtype)
else:
w_obj = W_NDimArray.from_scalar(space, w_obj)
assert isinstance(w_obj, W_NDimArray)
@@ -460,7 +460,7 @@
w_res = space.call_method(w_obj, '__array_wrap__', w_res)
return w_res
- def call_scalar(self, space, w_arg, in_dtype, out_dtype):
+ def call_scalar(self, space, w_arg, in_dtype):
w_val = self.func(in_dtype, w_arg.convert_to(space, in_dtype))
if isinstance(w_val, W_ObjectBox):
return w_val.w_obj
@@ -637,43 +637,38 @@
res_dtype = get_dtype_cache(space).w_booldtype
else:
res_dtype = calc_dtype
- if w_lhs.is_scalar() and w_rhs.is_scalar():
- return self.call_scalar(space,
- w_lhs.get_scalar_value(),
- w_rhs.get_scalar_value(),
- calc_dtype, res_dtype, out)
- if isinstance(w_lhs, boxes.W_GenericBox):
+ if (isinstance(w_lhs, W_GenericBox) and
+ isinstance(w_rhs, W_GenericBox) and out is None):
+ return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
+ if isinstance(w_lhs, W_GenericBox):
w_lhs = W_NDimArray.from_scalar(space, w_lhs)
assert isinstance(w_lhs, W_NDimArray)
- if isinstance(w_rhs, boxes.W_GenericBox):
+ if isinstance(w_rhs, W_GenericBox):
w_rhs = W_NDimArray.from_scalar(space, w_rhs)
assert isinstance(w_rhs, W_NDimArray)
new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
new_shape = shape_agreement(space, new_shape, out,
broadcast_down=False)
w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs)
if out is None:
- w_ret = W_NDimArray.from_shape(space, new_shape, res_dtype,
+ w_res = W_NDimArray.from_shape(space, new_shape, res_dtype,
w_instance=out_subtype)
else:
- w_ret = out
- w_ret = loop.call2(space, new_shape, self.func, calc_dtype,
- w_lhs, w_rhs, w_ret)
+ w_res = out
+ w_res = loop.call2(space, new_shape, self.func, calc_dtype,
+ w_lhs, w_rhs, w_res)
if out is None:
- w_ret = space.call_method(w_highpriority, '__array_wrap__', w_ret)
- return w_ret
+ if w_res.is_scalar():
+ return w_res.get_scalar_value()
+ w_res = space.call_method(w_highpriority, '__array_wrap__', w_res)
+ return w_res
- def call_scalar(self, space, w_lhs, w_rhs, in_dtype, out_dtype, out):
+ def call_scalar(self, space, w_lhs, w_rhs, in_dtype):
w_val = self.func(in_dtype,
w_lhs.convert_to(space, in_dtype),
w_rhs.convert_to(space, in_dtype))
- if out is None:
- return w_val
- w_val = out_dtype.coerce(space, w_val)
- if out.is_scalar():
- out.set_scalar_value(w_val)
- else:
- out.fill(space, w_val)
- return out
+ if isinstance(w_val, W_ObjectBox):
+ return w_val.w_obj
+ return w_val
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit