Author: Ronan Lamy <[email protected]>
Branch:
Changeset: r77796:cad6015d5380
Date: 2015-06-03 00:04 +0100
http://bitbucket.org/pypy/pypy/changeset/cad6015d5380/
Log: Merged use_min_scalar into default
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1349,3 +1349,6 @@
assert np.add(np.float16(0), np.longdouble(0)).dtype == np.longdouble
assert np.add(np.float16(0), np.complex64(0)).dtype == np.complex64
assert np.add(np.float16(0), np.complex128(0)).dtype == np.complex128
+ assert np.add(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
+ assert np.subtract(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
+ assert np.divide(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -21,7 +21,8 @@
from pypy.module.micronumpy.support import (_parse_signature, product,
get_storage_as_int, is_rhs_priority_higher)
from .casting import (
- can_cast_type, can_cast_to, find_result_type, promote_types)
+ can_cast_type, can_cast_array, can_cast_to,
+ find_result_type, promote_types)
from .boxes import W_GenericBox, W_ObjectBox
def done_if_true(dtype, val):
@@ -495,17 +496,12 @@
return dt_in, dt_out, self.func
def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'):
- use_min_scalar = False
if arg_dtype.is_object():
return arg_dtype, arg_dtype
in_casting = safe_casting_mode(casting)
for dt_in, dt_out in self.dtypes:
- if use_min_scalar:
- if not can_cast_array(space, w_arg, dt_in, in_casting):
- continue
- else:
- if not can_cast_type(space, arg_dtype, dt_in, in_casting):
- continue
+ if not can_cast_type(space, arg_dtype, dt_in, in_casting):
+ continue
if out is not None:
res_dtype = out.get_dtype()
if not can_cast_type(space, dt_out, res_dtype, casting):
@@ -605,21 +601,18 @@
w_rdtype.get_name(), w_ldtype.get_name(),
self.name)
- if self.are_common_types(w_ldtype, w_rdtype):
- if not w_lhs.is_scalar() and w_rhs.is_scalar():
- w_rdtype = w_ldtype
- elif w_lhs.is_scalar() and not w_rhs.is_scalar():
- w_ldtype = w_rdtype
- calc_dtype, dt_out, func = self.find_specialization(space, w_ldtype,
w_rdtype, out, casting)
if (isinstance(w_lhs, W_GenericBox) and
isinstance(w_rhs, W_GenericBox) and out is None):
- return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
+ return self.call_scalar(space, w_lhs, w_rhs, casting)
if isinstance(w_lhs, W_GenericBox):
w_lhs = W_NDimArray.from_scalar(space, w_lhs)
assert isinstance(w_lhs, W_NDimArray)
if isinstance(w_rhs, W_GenericBox):
w_rhs = W_NDimArray.from_scalar(space, w_rhs)
assert isinstance(w_rhs, W_NDimArray)
+ calc_dtype, dt_out, func = self.find_specialization(
+ space, w_ldtype, w_rdtype, out, casting, w_lhs, w_rhs)
+
new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
new_shape = shape_agreement(space, new_shape, out,
broadcast_down=False)
w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs)
@@ -637,7 +630,10 @@
w_res = space.call_method(w_highpriority, '__array_wrap__', w_res,
ctxt)
return w_res
- def call_scalar(self, space, w_lhs, w_rhs, in_dtype):
+ def call_scalar(self, space, w_lhs, w_rhs, casting):
+ in_dtype, out_dtype, func = self.find_specialization(
+ space, w_lhs.get_dtype(space), w_rhs.get_dtype(space),
+ out=None, casting=casting)
w_val = self.func(in_dtype,
w_lhs.convert_to(space, in_dtype),
w_rhs.convert_to(space, in_dtype))
@@ -645,7 +641,8 @@
return w_val.w_obj
return w_val
- def _find_specialization(self, space, l_dtype, r_dtype, out, casting):
+ def _find_specialization(self, space, l_dtype, r_dtype, out, casting,
+ w_arg1, w_arg2):
if (not self.allow_bool and (l_dtype.is_bool() or
r_dtype.is_bool()) or
not self.allow_complex and (l_dtype.is_complex() or
@@ -657,15 +654,23 @@
dtype = find_result_type(space, [], [l_dtype, r_dtype])
bool_dtype = get_dtype_cache(space).w_booldtype
return dtype, bool_dtype, self.func
- dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
+ dt_in, dt_out = self._calc_dtype(
+ space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2)
return dt_in, dt_out, self.func
- def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+ def find_specialization(self, space, l_dtype, r_dtype, out, casting,
+ w_arg1=None, w_arg2=None):
if self.simple_binary:
if out is None and not (l_dtype.is_object() or
r_dtype.is_object()):
- dtype = promote_types(space, l_dtype, r_dtype)
+ if w_arg1 is not None and w_arg2 is not None:
+ w_arg1 = convert_to_array(space, w_arg1)
+ w_arg2 = convert_to_array(space, w_arg2)
+ dtype = find_result_type(space, [w_arg1, w_arg2], [])
+ else:
+ dtype = promote_types(space, l_dtype, r_dtype)
return dtype, dtype, self.func
- return self._find_specialization(space, l_dtype, r_dtype, out, casting)
+ return self._find_specialization(
+ space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2)
def find_binop_type(self, space, dtype):
"""Find a valid dtype signature of the form xx->x"""
@@ -686,15 +691,21 @@
"requested type has type code '%s'" % (self.name, dtype.char))
- def _calc_dtype(self, space, l_dtype, r_dtype, out=None, casting='unsafe'):
- use_min_scalar = False
+ def _calc_dtype(self, space, l_dtype, r_dtype, out, casting,
+ w_arg1, w_arg2):
if l_dtype.is_object() or r_dtype.is_object():
dtype = get_dtype_cache(space).w_objectdtype
return dtype, dtype
+ use_min_scalar = (w_arg1 is not None and w_arg2 is not None and
+ ((w_arg1.is_scalar() and not w_arg2.is_scalar()) or
+ (not w_arg1.is_scalar() and w_arg2.is_scalar())))
in_casting = safe_casting_mode(casting)
for dt_in, dt_out in self.dtypes:
if use_min_scalar:
- if not can_cast_array(space, w_arg, dt_in, in_casting):
+ w_arg1 = convert_to_array(space, w_arg1)
+ w_arg2 = convert_to_array(space, w_arg2)
+ if not (can_cast_array(space, w_arg1, dt_in, in_casting) and
+ can_cast_array(space, w_arg2, dt_in, in_casting)):
continue
else:
if not (can_cast_type(space, l_dtype, dt_in, in_casting) and
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit