Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77458:76dd9c09a0e8
Date: 2015-05-21 20:01 +0100
http://bitbucket.org/pypy/pypy/changeset/76dd9c09a0e8/

Log:    Implement numpy's complicated scalar handling rules in result_type()

diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -63,8 +63,9 @@
             kind = kind_ordering[dtype.kind]
             if kind > max_array_kind:
                 max_array_kind = kind
-    #use_min_scalar = bool(arrays_w) and not all_scalars and max_array_kind >= max_scalar_kind
-    use_min_scalar = False
+    # Type scalars by value only when some operand is not a scalar and the
+    # arrays' maximum kind is at least the scalars' maximum kind.
+    use_min_scalar = bool(arrays_w) and not all_scalars and max_array_kind >= max_scalar_kind
     if not use_min_scalar:
         for w_array in arrays_w:
             if result is None:
@@ -76,6 +77,35 @@
                 result = dtype
             else:
                 result = _promote_types(space, result, dtype)
+    else:
+        # Value-based path: a numeric scalar contributes the dtype given by
+        # its min_dtype() instead of its own dtype; small_unsigned tracks
+        # whether the running result may still switch signedness.
+        small_unsigned = False
+        for w_array in arrays_w:
+            dtype = w_array.get_dtype()
+            small_unsigned_scalar = False
+            if w_array.is_scalar() and dtype.is_number():
+                num, alt_num = w_array.get_scalar_value().min_dtype()
+                # NOTE(review): presumably num != alt_num when the value fits
+                # both a signed type and its unsigned partner -- confirm.
+                small_unsigned_scalar = (num != alt_num)
+                dtype = num2dtype(space, num)
+            if result is None:
+                result = dtype
+                small_unsigned = small_unsigned_scalar
+            else:
+                result, small_unsigned = _promote_types_su(
+                        space, result, dtype,
+                        small_unsigned, small_unsigned_scalar)
+        for dtype in dtypes_w:
+            if result is None:
+                result = dtype
+                small_unsigned = False
+            else:
+                result, small_unsigned = _promote_types_su(
+                        space, result, dtype,
+                        small_unsigned, False)
     return result
 
 
@@ -215,6 +245,38 @@
             return dt1
     raise oefmt(space.w_TypeError, "invalid type promotion")
 
+def _promote_types_su(space, dt1, dt2, su1, su2):
+    """Like _promote_types(), but handles the small_unsigned flag as well.
+
+    su1/su2 mark an operand whose dtype may be switched to its signed or
+    unsigned partner: to unsigned when the other operand is bool or
+    unsigned, to signed otherwise.
+
+    Returns a (dtype, small_unsigned) tuple: the promoted dtype and whether
+    the flag carries over to it.
+    """
+    if su1:
+        if dt2.is_bool() or dt2.is_unsigned():
+            dt1 = dt1.as_unsigned(space)
+        else:
+            dt1 = dt1.as_signed(space)
+    elif su2:
+        if dt1.is_bool() or dt1.is_unsigned():
+            dt2 = dt2.as_unsigned(space)
+        else:
+            dt2 = dt2.as_signed(space)
+    # The flag survives only if the wider operand carries it and the
+    # narrower one is flagged too or is not signed; for equal sizes both
+    # operands must carry it.
+    if dt1.elsize < dt2.elsize:
+        su = su2 and (su1 or not dt1.is_signed())
+    elif dt1.elsize == dt2.elsize:
+        su = su1 and su2
+    else:
+        su = su1 and (su2 or not dt2.is_signed())
+    return _promote_types(space, dt1, dt2), su
+
+
 
 def find_dtype_for_scalar(space, w_obj, current_guess=None):
     from .boxes import W_GenericBox
diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -164,6 +164,29 @@
     def is_native(self):
         return self.byteorder in (NPY.NATIVE, NPY.NATBYTE)
 
+    def as_signed(self, space):
+        """Convert from an unsigned integer dtype to its signed partner.
+
+        Any other dtype is returned unchanged.
+        """
+        if self.is_unsigned():
+            # Relies on each unsigned type number directly following its
+            # signed partner (NPY.BYTE/NPY.UBYTE, NPY.SHORT/NPY.USHORT, ...).
+            return num2dtype(space, self.num - 1)
+        else:
+            return self
+
+    def as_unsigned(self, space):
+        """Convert from a signed integer dtype to its unsigned partner.
+
+        Any other dtype is returned unchanged.
+        """
+        if self.is_signed():
+            # Same numbering assumption as in as_signed().
+            return num2dtype(space, self.num + 1)
+        else:
+            return self
+
     def get_float_dtype(self, space):
         assert self.is_complex()
         dtype = get_dtype_cache(space).component_dtypes[self.num]
diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py
--- a/pypy/module/micronumpy/test/test_casting.py
+++ b/pypy/module/micronumpy/test/test_casting.py
@@ -1,7 +1,8 @@
 from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
-from pypy.module.micronumpy.descriptor import get_dtype_cache
+from pypy.module.micronumpy.descriptor import get_dtype_cache, num2dtype
 from pypy.module.micronumpy.casting import (
-    find_binop_result_dtype, can_cast_type)
+    find_binop_result_dtype, can_cast_type, _promote_types_su)
+import pypy.module.micronumpy.constants as NPY
 
 
 class AppTestNumSupport(BaseNumpyAppTest):
@@ -140,6 +141,24 @@
     assert can_cast_type(space, dt_bool, dt_bool, 'same_kind')
     assert can_cast_type(space, dt_bool, dt_bool, 'unsafe')
 
+def test_promote_types_su(space):
+    dt_int8 = num2dtype(space, NPY.BYTE)
+    dt_uint8 = num2dtype(space, NPY.UBYTE)
+    dt_int16 = num2dtype(space, NPY.SHORT)
+    dt_uint16 = num2dtype(space, NPY.USHORT)
+    # The results must be signed
+    assert _promote_types_su(space, dt_int8, dt_int16, False, False) == (dt_int16, False)
+    assert _promote_types_su(space, dt_int8, dt_int16, True, False) == (dt_int16, False)
+    assert _promote_types_su(space, dt_int8, dt_int16, False, True) == (dt_int16, False)
+
+    # The results may be unsigned
+    assert _promote_types_su(space, dt_int8, dt_int16, True, True) == (dt_int16, True)
+    assert _promote_types_su(space, dt_uint8, dt_int16, False, True) == (dt_uint16, True)
+
+    # A flagged first operand takes on the signedness of the other operand
+    assert _promote_types_su(space, dt_uint8, dt_int16, True, False) == (dt_int16, False)
+    assert _promote_types_su(space, dt_int8, dt_uint16, True, False) == (dt_uint16, False)
+
 
 class TestCoercion(object):
     def test_binops(self, space):
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to