Author: Ronan Lamy <[email protected]>
Branch: use_min_scalar
Changeset: r77776:02c9c753b06c
Date: 2015-06-02 05:31 +0100
http://bitbucket.org/pypy/pypy/changeset/02c9c753b06c/
Log: correct handling of scalars for non-simple binary ufuncs
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py
b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1351,3 +1351,4 @@
assert np.add(np.float16(0), np.complex128(0)).dtype == np.complex128
assert np.add(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
assert np.subtract(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
+ assert np.divide(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -21,7 +21,8 @@
from pypy.module.micronumpy.support import (_parse_signature, product,
get_storage_as_int, is_rhs_priority_higher)
from .casting import (
- can_cast_type, can_cast_to, find_result_type, promote_types)
+ can_cast_type, can_cast_array, can_cast_to,
+ find_result_type, promote_types)
from .boxes import W_GenericBox, W_ObjectBox
def done_if_true(dtype, val):
@@ -642,12 +643,6 @@
def _find_specialization(self, space, l_dtype, r_dtype, out, casting,
w_arg1, w_arg2):
- if (self.are_common_types(l_dtype, r_dtype) and
- w_arg1 is not None and w_arg2 is not None):
- if not w_arg1.is_scalar() and w_arg2.is_scalar():
- r_dtype = l_dtype
- elif w_arg1.is_scalar() and not w_arg2.is_scalar():
- l_dtype = r_dtype
if (not self.allow_bool and (l_dtype.is_bool() or
r_dtype.is_bool()) or
not self.allow_complex and (l_dtype.is_complex() or
@@ -659,7 +654,8 @@
dtype = find_result_type(space, [], [l_dtype, r_dtype])
bool_dtype = get_dtype_cache(space).w_booldtype
return dtype, bool_dtype, self.func
- dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
+ dt_in, dt_out = self._calc_dtype(
+ space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2)
return dt_in, dt_out, self.func
def find_specialization(self, space, l_dtype, r_dtype, out, casting,
@@ -695,15 +691,21 @@
"requested type has type code '%s'" % (self.name, dtype.char))
- def _calc_dtype(self, space, l_dtype, r_dtype, out=None, casting='unsafe'):
- use_min_scalar = False
+ def _calc_dtype(self, space, l_dtype, r_dtype, out, casting,
+ w_arg1, w_arg2):
if l_dtype.is_object() or r_dtype.is_object():
dtype = get_dtype_cache(space).w_objectdtype
return dtype, dtype
+ use_min_scalar = (w_arg1 is not None and w_arg2 is not None and
+ ((w_arg1.is_scalar() and not w_arg2.is_scalar()) or
+ (not w_arg1.is_scalar() and w_arg2.is_scalar())))
in_casting = safe_casting_mode(casting)
for dt_in, dt_out in self.dtypes:
if use_min_scalar:
- if not can_cast_array(space, w_arg, dt_in, in_casting):
+ w_arg1 = convert_to_array(space, w_arg1)
+ w_arg2 = convert_to_array(space, w_arg2)
+ if not (can_cast_array(space, w_arg1, dt_in, in_casting) and
+ can_cast_array(space, w_arg2, dt_in, in_casting)):
continue
else:
if not (can_cast_type(space, l_dtype, dt_in, in_casting) and
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit