Author: Justin Peel <notmuchtot...@gmail.com>
Branch: numpy-dtype
Changeset: r46226:959ca3a44df9
Date: 2011-08-02 23:30 -0600
http://bitbucket.org/pypy/pypy/changeset/959ca3a44df9/

Log:    Added find_result_dtype; binary ops now promote both operands to a common result dtype.

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -46,6 +46,8 @@
 UNSIGNEDLTR = 'u'
 COMPLEXLTR = 'c'
 
+# Promotion rank of each dtype kind letter, consumed by find_result_dtype.
+kind_dict = {
+    'b': 0,          # presumably bool
+    'u': 1, 'i': 1,  # unsigned / signed integers
+    'f': 2, 'c': 2,  # floating point / complex
+}
+
 class Dtype(Wrappable):
     # attributes: type, kind, typeobj?(I think it should point to np.float64 or
     # the like), byteorder, flags, type_num, elsize, alignment, subarray,
@@ -174,14 +176,41 @@
         raise OperationError(space.w_TypeError,
                             space.wrap("data type not understood"))
 
-def find_base_dtype(dtype1, dtype2):
+def find_result_dtype(d1, d2):
+    """Return the dtype produced by a binary operation on d1 and d2.
+
+    Rough promotion order: bool < integers < float/complex.  Mixing a signed
+    and an unsigned integer widens to a type that can hold both, falling back
+    to float64 when no integer type is wide enough.  Still incomplete (see the
+    XXX notes), so feel free to improve it.
+    """
+    # Order the operands so that dtype2 has the larger type number.
+    if d1.num > d2.num:
+        dtype1 = d2
+        dtype2 = d1
+    else:
+        dtype1 = d1
+        dtype2 = d2
     num1 = dtype1.num
     num2 = dtype2.num
-    # this is much more complex
-    if num1 < num2:
+    kind1 = dtype1.kind
+    kind2 = dtype2.kind
+    if kind1 == kind2:
+        # Same kind: the larger type number is at least as wide.
         return dtype2
-    return dtype
-
+    kind_num1 = kind_dict[kind1]
+    kind_num2 = kind_dict[kind2]
+    if kind_num1 == kind_num2:
+        if kind_num2 == 2:
+            # float combined with complex: the complex operand wins (this used
+            # to fall through to the integer logic and return Int64_dtype).
+            # XXX: float64 + complex64 should really give complex128.
+            return dtype2
+        # One signed and one unsigned integer.
+        if kind2 == SIGNEDLTR:
+            # XXX: not always wide enough, e.g. long + uint32 when longs are
+            # 32 bits, or longlong + ulong when longs are 64 bits.
+            return dtype2
+        if num2 < UInt32_num:
+            # uint8/uint16: the next type number is presumably the signed
+            # type one size up.
+            return _dtype_list[num2+1]
+        # Assumes ulong is the only unsigned type numbered between uint32
+        # and uint64; it is 64 bits wide when LONG_BIT == 64.  (The previous
+        # check compared against the signed Long_num, which can never match
+        # in this unsigned-only branch.)
+        if num2 == UInt64_num or (LONG_BIT == 64 and num2 > UInt32_num):
+            # No signed integer can hold a 64-bit unsigned value.
+            return Float64_dtype
+        # dtype2 is a 32-bit unsigned integer.
+        return Int64_dtype
+    if kind_num1 == 1 and num2 == Float32_num and num1 >= Int32_num:
+        # float32 cannot represent every 32-bit (or wider) integer exactly,
+        # so widen to float64.  The previous condition was always false.
+        # XXX: the same widening is missing for such ints + complex64.
+        return Float64_dtype
+    return dtype2
 
 def descr_new_dtype(space, w_type, w_string_or_type):
     return space.wrap(get_dtype(space, w_type, w_string_or_type))
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -2,7 +2,7 @@
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.interpreter.gateway import interp2app, unwrap_spec
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
-from pypy.module.micronumpy.interp_dtype import Dtype, Float64_num, Int32_num, 
Float64_dtype, get_dtype, find_scalar_dtype, find_base_dtype
+from pypy.module.micronumpy.interp_dtype import Dtype, Float64_num, Int32_num, 
Float64_dtype, get_dtype, find_scalar_dtype, find_result_dtype
 from pypy.module.micronumpy.interp_support import Signature
 from pypy.module.micronumpy import interp_ufuncs
 from pypy.objspace.std.floatobject import float2string as float2string_orig
@@ -417,16 +417,24 @@
 
     def __init__(self, function, left, right, signature):
         VirtualArray.__init__(self, signature)
-        self.function = function
         self.left = left
         self.right = right
         dtype = self.left.find_dtype()
         dtype2 = self.right.find_dtype()
-        # this is more complicated than this.
-        # for instance int32 + uint32 = int64
-        if dtype.num != dtype.num:
-            dtype = find_base_dtype(dtype, dtype2)
-        self.dtype = dtype
+        if dtype.num == dtype2.num:
+            # Both operands already share a dtype: apply function directly.
+            self.dtype = dtype
+            self.function = function
+            return
+        # Promote to a common result dtype and wrap function so that each
+        # operand whose dtype differs from it is cast first.
+        newdtype = find_result_dtype(dtype, dtype2)
+        cast = newdtype.cast
+        cast_left = dtype.num != newdtype.num
+        cast_right = dtype2.num != newdtype.num
+        if cast_left and cast_right:
+            self.function = lambda x, y: function(cast(x), cast(y))
+        elif cast_left:
+            self.function = lambda x, y: function(cast(x), y)
+        else:
+            self.function = lambda x, y: function(x, cast(y))
+        self.dtype = newdtype
 
     def _del_sources(self):
         self.left = None
diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -35,3 +35,13 @@
         assert a[0] == 1
         assert a[1] == 2
         assert a[2] == 3
+
+    def test_bool_binop_types(self):
+        # Adding a bool array to an array of any listed dtype must produce
+        # an array of exactly that dtype.
+        from numpy import array, dtype
+        types = ('?','b','B','h','H','i','I','l','L','q','Q','f','d','g')
+        expected = [dtype(t) for t in types]
+        bool_arr = array([True],'?')
+        for typecode, want in zip(types, expected):
+            assert (bool_arr + array([0], typecode)).dtype is want
+# TODO: add more tests for binop result types
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to