Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77670:969b168bf034
Date: 2015-05-29 02:21 +0100
http://bitbucket.org/pypy/pypy/changeset/969b168bf034/
Log: Move casting_table and promotion_table to casting.py
diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -7,8 +7,8 @@
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy import constants as NPY
 from .types import (
-    Bool, ULong, Long, Float64, Complex64, UnicodeType, VoidType, ObjectType,
-    promotion_table)
+    Bool, ULong, Long, Float64, Complex64, StringType, UnicodeType, VoidType, ObjectType,
+    int_types, float_types, complex_types, number_types, all_types)
 from .descriptor import (
     get_dtype_cache, as_dtype, is_scalar_w, variable_dtype, new_string_dtype,
     new_unicode_dtype, num2dtype)
@@ -324,3 +324,84 @@
     elif space.isinstance_w(w_obj, space.w_str):
         return variable_dtype(space, 'S%d' % space.len_w(w_obj))
     return object_dtype
+
+#_________________________
+
+
+casting_table = [[False] * NPY.NTYPES for _ in range(NPY.NTYPES)]
+
+def enable_cast(type1, type2):
+    casting_table[type1.num][type2.num] = True
+
+def _can_cast(type1, type2):
+    return casting_table[type1.num][type2.num]
+
+for tp in all_types:
+    enable_cast(tp, tp)
+    if tp.num != NPY.DATETIME:
+        enable_cast(Bool, tp)
+    enable_cast(tp, ObjectType)
+    enable_cast(tp, VoidType)
+enable_cast(StringType, UnicodeType)
+#enable_cast(Bool, TimeDelta)
+
+for tp in number_types:
+    enable_cast(tp, StringType)
+    enable_cast(tp, UnicodeType)
+
+for tp1 in int_types:
+    for tp2 in int_types:
+        if tp1.signed:
+            if tp2.signed and tp1.basesize() <= tp2.basesize():
+                enable_cast(tp1, tp2)
+        else:
+            if tp2.signed and tp1.basesize() < tp2.basesize():
+                enable_cast(tp1, tp2)
+            elif not tp2.signed and tp1.basesize() <= tp2.basesize():
+                enable_cast(tp1, tp2)
+for tp1 in int_types:
+    for tp2 in float_types + complex_types:
+        size1 = tp1.basesize()
+        size2 = tp2.basesize()
+        if (size1 < 8 and size2 > size1) or (size1 >= 8 and size2 >= size1):
+            enable_cast(tp1, tp2)
+for tp1 in float_types:
+    for tp2 in float_types + complex_types:
+        if tp1.basesize() <= tp2.basesize():
+            enable_cast(tp1, tp2)
+for tp1 in complex_types:
+    for tp2 in complex_types:
+        if tp1.basesize() <= tp2.basesize():
+            enable_cast(tp1, tp2)
+
+promotion_table = [[-1] * NPY.NTYPES for _ in range(NPY.NTYPES)]
+def promotes(tp1, tp2, tp3):
+    if tp3 is None:
+        num = -1
+    else:
+        num = tp3.num
+    promotion_table[tp1.num][tp2.num] = num
+
+
+for tp in all_types:
+    promotes(tp, ObjectType, ObjectType)
+    promotes(ObjectType, tp, ObjectType)
+
+for tp1 in [Bool] + number_types:
+    for tp2 in [Bool] + number_types:
+        if tp1 is tp2:
+            promotes(tp1, tp1, tp1)
+        elif _can_cast(tp1, tp2):
+            promotes(tp1, tp2, tp2)
+        elif _can_cast(tp2, tp1):
+            promotes(tp1, tp2, tp1)
+        else:
+            # Brute-force search for the least upper bound
+            result = None
+            for tp3 in number_types:
+                if _can_cast(tp1, tp3) and _can_cast(tp2, tp3):
+                    if result is None:
+                        result = tp3
+                    elif _can_cast(tp3, result) and not _can_cast(result, tp3):
+                        result = tp3
+            promotes(tp1, tp2, result)
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -156,6 +156,7 @@
 
     def can_cast_to(self, other):
         # equivalent to PyArray_CanCastSafely
+        from .casting import casting_table
         return casting_table[self.num][other.num]
 
 class Primitive(object):
@@ -2503,85 +2504,9 @@
 _setup()
 del _setup
 
-casting_table = [[False] * NPY.NTYPES for _ in range(NPY.NTYPES)]
 number_types = int_types + float_types + complex_types
 all_types = [Bool] + number_types + [ObjectType, StringType, UnicodeType, VoidType]
 
-def enable_cast(type1, type2):
-    casting_table[type1.num][type2.num] = True
-
-def _can_cast(type1, type2):
-    return casting_table[type1.num][type2.num]
-
-for tp in all_types:
-    enable_cast(tp, tp)
-    if tp.num != NPY.DATETIME:
-        enable_cast(Bool, tp)
-    enable_cast(tp, ObjectType)
-    enable_cast(tp, VoidType)
-enable_cast(StringType, UnicodeType)
-#enable_cast(Bool, TimeDelta)
-
-for tp in number_types:
-    enable_cast(tp, StringType)
-    enable_cast(tp, UnicodeType)
-
-for tp1 in int_types:
-    for tp2 in int_types:
-        if tp1.signed:
-            if tp2.signed and tp1.basesize() <= tp2.basesize():
-                enable_cast(tp1, tp2)
-        else:
-            if tp2.signed and tp1.basesize() < tp2.basesize():
-                enable_cast(tp1, tp2)
-            elif not tp2.signed and tp1.basesize() <= tp2.basesize():
-                enable_cast(tp1, tp2)
-for tp1 in int_types:
-    for tp2 in float_types + complex_types:
-        size1 = tp1.basesize()
-        size2 = tp2.basesize()
-        if (size1 < 8 and size2 > size1) or (size1 >= 8 and size2 >= size1):
-            enable_cast(tp1, tp2)
-for tp1 in float_types:
-    for tp2 in float_types + complex_types:
-        if tp1.basesize() <= tp2.basesize():
-            enable_cast(tp1, tp2)
-for tp1 in complex_types:
-    for tp2 in complex_types:
-        if tp1.basesize() <= tp2.basesize():
-            enable_cast(tp1, tp2)
-
-promotion_table = [[-1] * NPY.NTYPES for _ in range(NPY.NTYPES)]
-def promotes(tp1, tp2, tp3):
-    if tp3 is None:
-        num = -1
-    else:
-        num = tp3.num
-    promotion_table[tp1.num][tp2.num] = num
-
-
-for tp in all_types:
-    promotes(tp, ObjectType, ObjectType)
-    promotes(ObjectType, tp, ObjectType)
-
-for tp1 in [Bool] + number_types:
-    for tp2 in [Bool] + number_types:
-        if tp1 is tp2:
-            promotes(tp1, tp1, tp1)
-        elif _can_cast(tp1, tp2):
-            promotes(tp1, tp2, tp2)
-        elif _can_cast(tp2, tp1):
-            promotes(tp1, tp2, tp1)
-        else:
-            # Brute-force search for the least upper bound
-            result = None
-            for tp3 in number_types:
-                if _can_cast(tp1, tp3) and _can_cast(tp2, tp3):
-                    if result is None:
-                        result = tp3
-                    elif _can_cast(tp3, result) and not _can_cast(result, tp3):
-                        result = tp3
-            promotes(tp1, tp2, result)
 
 
 _int_types = [(Int8, UInt8), (Int16, UInt16), (Int32, UInt32),
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit