Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77418:c9ba53952909
Date: 2015-05-20 03:56 +0100
http://bitbucket.org/pypy/pypy/changeset/c9ba53952909/

Log:    fix promote_types(<numeric>, <string>)

diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -10,7 +10,9 @@
 from .types import (
     Bool, ULong, Long, Float64, Complex64, UnicodeType, VoidType, ObjectType,
     promotion_table)
-from .descriptor import get_dtype_cache, as_dtype, is_scalar_w, variable_dtype
+from .descriptor import (
+    get_dtype_cache, as_dtype, is_scalar_w, variable_dtype, new_string_dtype,
+    new_unicode_dtype)
 
 @jit.unroll_safe
 def result_type(space, __args__):
@@ -151,39 +153,41 @@
     if dt1.num > dt2.num:
         dt1, dt2 = dt2, dt1
 
-    # for now this means mixing signed and unsigned
-    if dt2.kind == NPY.SIGNEDLTR:
-        # if dt2 has a greater number of bytes, then just go with it
-        if dt1.itemtype.get_element_size() < dt2.itemtype.get_element_size():
-            return dt2
-        # we need to promote both dtypes
-        dtypenum = dt2.num + 2
-    elif dt2.num == NPY.ULONGLONG or (LONG_BIT == 64 and dt2.num == NPY.ULONG):
-        # UInt64 + signed = Float64
-        dtypenum = NPY.DOUBLE
-    elif dt2.is_flexible():
-        # For those operations that get here (concatenate, stack),
-        # flexible types take precedence over numeric type
-        if dt2.is_record():
-            return dt2
-        if dt1.is_str_or_unicode():
-            if dt2.elsize >= dt1.elsize:
+    if dt2.is_str():
+        if dt1.is_str():
+            if dt1.elsize > dt2.elsize:
+                return dt1
+            else:
                 return dt2
+        else:  # dt1 is numeric
+            dt1_size = dt1.itemtype.strlen
+            if dt1_size > dt2.elsize:
+                return new_string_dtype(space, dt1_size)
+            else:
+                return dt2
+    elif dt2.is_unicode():
+        if dt1.is_unicode():
+            if dt1.elsize > dt2.elsize:
+                return dt1
+            else:
+                return dt2
+        elif dt1.is_str():
+            if dt2.elsize >= 4 * dt1.elsize:
+                return dt2
+            else:
+                return new_unicode_dtype(space, 4 * dt1.elsize)
+        else:  # dt1 is numeric
+            dt1_size = 4 * dt1.itemtype.strlen
+            if dt1_size > dt2.elsize:
+                return new_unicode_dtype(space, dt1_size)
+            else:
+                return dt2
+    else:
+        assert dt2.num == NPY.VOID
+        if can_cast_type(space, dt1, dt2, casting='equiv'):
             return dt1
-        return dt2
-    else:
-        # increase to the next signed type
-        dtypenum = dt2.num + 1
-    newdtype = get_dtype_cache(space).dtypes_by_num[dtypenum]
+    raise oefmt(space.w_TypeError, "invalid type promotion")
 
-    if (newdtype.itemtype.get_element_size() > dt2.itemtype.get_element_size() 
or
-            newdtype.kind == NPY.FLOATINGLTR):
-        return newdtype
-    else:
-        # we only promoted to long on 32-bit or to longlong on 64-bit
-        # this is really for dealing with the Long and Ulong dtypes
-        dtypenum += 2
-        return get_dtype_cache(space).dtypes_by_num[dtypenum]
 
 def find_dtype_for_scalar(space, w_obj, current_guess=None):
     from .boxes import W_GenericBox
diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -146,6 +146,9 @@
     def is_str(self):
         return self.num == NPY.STRING
 
+    def is_unicode(self):
+        return self.num == NPY.UNICODE
+
     def is_object(self):
         return self.num == NPY.OBJECT
 
diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py
--- a/pypy/module/micronumpy/test/test_casting.py
+++ b/pypy/module/micronumpy/test/test_casting.py
@@ -130,6 +130,7 @@
         assert np.promote_types('i8', 'f4') == np.dtype('float64')
         assert np.promote_types('>i8', '<c8') == np.dtype('complex128')
         assert np.promote_types('i4', 'S8') == np.dtype('S11')
+        assert np.promote_types('f4', 'S8') == np.dtype('S32')
 
 def test_can_cast_same_type(space):
     dt_bool = get_dtype_cache(space).w_booldtype
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -721,6 +721,7 @@
 
 class Float(Primitive):
     _mixin_ = True
+    strlen = 32
 
     def _coerce(self, space, w_item):
         if w_item is None:
@@ -1047,7 +1048,7 @@
         else:
             return x
 
-class Float16(BaseType, Float):
+class Float16(Float, BaseType):
     _STORAGE_T = rffi.USHORT
     T = rffi.SHORT
     num = NPY.HALF
@@ -1092,7 +1093,7 @@
             hbits = byteswap(hbits)
         raw_storage_setitem_unaligned(storage, i + offset, hbits)
 
-class Float32(BaseType, Float):
+class Float32(Float, BaseType):
     T = rffi.FLOAT
     num = NPY.FLOAT
     kind = NPY.FLOATINGLTR
@@ -1101,7 +1102,7 @@
     format_code = "f"
     max_value = 3.4e38
 
-class Float64(BaseType, Float):
+class Float64(Float, BaseType):
     T = rffi.DOUBLE
     num = NPY.DOUBLE
     kind = NPY.FLOATINGLTR
@@ -1112,6 +1113,7 @@
 
 class ComplexFloating(object):
     _mixin_ = True
+    strlen = 64
 
     def _coerce(self, space, w_item):
         if w_item is None:
@@ -1721,7 +1723,7 @@
     ComponentType = Float64
 
 if boxes.long_double_size == 8:
-    class FloatLong(BaseType, Float):
+    class FloatLong(Float, BaseType):
         T = rffi.DOUBLE
         num = NPY.LONGDOUBLE
         kind = NPY.FLOATINGLTR
@@ -1739,7 +1741,7 @@
         ComponentType = FloatLong
 
 elif boxes.long_double_size in (12, 16):
-    class FloatLong(BaseType, Float):
+    class FloatLong(Float, BaseType):
         T = rffi.LONGDOUBLE
         num = NPY.LONGDOUBLE
         kind = NPY.FLOATINGLTR
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to