Author: Justin Peel <notmuchtot...@gmail.com>
Branch: numpy-dtype
Changeset: r46226:959ca3a44df9
Date: 2011-08-02 23:30 -0600
http://bitbucket.org/pypy/pypy/changeset/959ca3a44df9/
Log: added find_result_dtype. binops should work correctly now. diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py --- a/pypy/module/micronumpy/interp_dtype.py +++ b/pypy/module/micronumpy/interp_dtype.py @@ -46,6 +46,8 @@ UNSIGNEDLTR = 'u' COMPLEXLTR = 'c' +kind_dict = {'b': 0, 'u': 1, 'i': 1, 'f': 2, 'c': 2} + class Dtype(Wrappable): # attributes: type, kind, typeobj?(I think it should point to np.float64 or # the like), byteorder, flags, type_num, elsize, alignment, subarray, @@ -174,14 +176,41 @@ raise OperationError(space.w_TypeError, space.wrap("data type not understood")) -def find_base_dtype(dtype1, dtype2): +def find_result_dtype(d1, d2): + # this function is for determining the result dtype of bin ops, etc. + # it is kind of a mess so feel free to improve it + + # first make sure larger num is in d2 + if d1.num > d2.num: + dtype1 = d2 + dtype2 = d1 + else: + dtype1 = d1 + dtype2 = d2 num1 = dtype1.num num2 = dtype2.num - # this is much more complex - if num1 < num2: + kind1 = dtype1.kind + kind2 = dtype2.kind + if kind1 == kind2: + # dtype2 has the greater number return dtype2 - return dtype - + kind_num1 = kind_dict[kind1] + kind_num2 = kind_dict[kind2] + if kind_num1 == kind_num2: # two kinds of integers or float and complex + # XXX: Need to deal with float and complex combo here also + if kind2 == SIGNEDLTR: + return dtype2 + if num2 < UInt32_num: + return _dtype_list[num2+1] + if num2 == UInt64_num or (LONG_BIT == 64 and num2 == Long_num): # UInt64 + return Float64_dtype + # dtype2 is uint32 + return Int64_dtype + if kind_num1 == 1: # is an integer + if num2 == Float32_num and num2 == UInt64_num or \ + (LONG_BIT == 64 and num2 == Long_num): + return Float64_dtype + return dtype2 def descr_new_dtype(space, w_type, w_string_or_type): return space.wrap(get_dtype(space, w_type, w_string_or_type)) diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- 
a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -2,7 +2,7 @@ from pypy.interpreter.error import OperationError, operationerrfmt from pypy.interpreter.gateway import interp2app, unwrap_spec from pypy.interpreter.typedef import TypeDef, GetSetProperty -from pypy.module.micronumpy.interp_dtype import Dtype, Float64_num, Int32_num, Float64_dtype, get_dtype, find_scalar_dtype, find_base_dtype +from pypy.module.micronumpy.interp_dtype import Dtype, Float64_num, Int32_num, Float64_dtype, get_dtype, find_scalar_dtype, find_result_dtype from pypy.module.micronumpy.interp_support import Signature from pypy.module.micronumpy import interp_ufuncs from pypy.objspace.std.floatobject import float2string as float2string_orig @@ -417,16 +417,24 @@ def __init__(self, function, left, right, signature): VirtualArray.__init__(self, signature) - self.function = function self.left = left self.right = right dtype = self.left.find_dtype() dtype2 = self.right.find_dtype() - # this is more complicated than this. 
- # for instance int32 + uint32 = int64 - if dtype.num != dtype.num: - dtype = find_base_dtype(dtype, dtype2) - self.dtype = dtype + if dtype.num != dtype2.num: + newdtype = find_result_dtype(dtype, dtype2) + cast = newdtype.cast + if dtype.num != newdtype.num: + if dtype2.num != newdtype.num: + self.function = lambda x, y: function(cast(x), cast(y)) + else: + self.function = lambda x, y: function(cast(x), y) + else: + self.function = lambda x, y: function(x, cast(y)) + self.dtype = newdtype + else: + self.dtype = dtype + self.function = function def _del_sources(self): self.left = None diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py --- a/pypy/module/micronumpy/test/test_dtypes.py +++ b/pypy/module/micronumpy/test/test_dtypes.py @@ -35,3 +35,13 @@ assert a[0] == 1 assert a[1] == 2 assert a[2] == 3 + + def test_bool_binop_types(self): + from numpy import array, dtype + types = ('?','b','B','h','H','i','I','l','L','q','Q','f','d','g') + dtypes = [dtype(t) for t in types] + N = len(types) + a = array([True],'?') + for i in xrange(N): + assert (a + array([0], types[i])).dtype is dtypes[i] +# need more tests for binop result types _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit