Author: Alex Gaynor <alex.gay...@gmail.com>
Branch: numpy-dtype-alt
Changeset: r46706:5377b6e0918b
Date: 2011-08-22 12:41 -0500
http://bitbucket.org/pypy/pypy/changeset/5377b6e0918b/
Log: fix for sum/prod with various dtypes. breaks test_zjit.

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -218,6 +218,10 @@
 class IntegerArithmeticDtype(object):
     _mixin_ = True
 
+    @binop
+    def add(self, v1, v2):
+        return v1 + v2
+
     def str_format(self, item):
         return str(widen(self.unbox(item)))
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -92,13 +92,19 @@
                 reduce_driver.jit_merge_point(signature=self.signature,
                                               self=self, res_dtype=res_dtype,
                                               size=size, i=i, result=result)
-                result = getattr(res_dtype, op_name)(result, self.eval(i))
+                result = getattr(res_dtype, op_name)(
+                    result,
+                    self.eval(i).convert_to(res_dtype)
+                )
                 i += 1
             return result
 
         def impl(self, space):
-            result = space.fromcache(interp_dtype.W_Float64Dtype).box(init).convert_to(self.find_dtype())
-            return loop(self, self.find_dtype(), result, self.find_size()).wrap(space)
+            dtype = interp_ufuncs.find_unaryop_result_dtype(
+                space, self.find_dtype(), promote_to_largest=True
+            )
+            result = dtype.adapt_val(init)
+            return loop(self, dtype, result, self.find_size()).wrap(space)
         return func_with_new_name(impl, "reduce_%s_impl" % op_name)
 
     def _reduce_max_min_impl(op_name):
@@ -178,8 +184,8 @@
     def descr_any(self, space):
         return space.wrap(self._any())
 
-    descr_sum = _reduce_sum_prod_impl("add", 0.0)
-    descr_prod = _reduce_sum_prod_impl("mul", 1.0)
+    descr_sum = _reduce_sum_prod_impl("add", 0)
+    descr_prod = _reduce_sum_prod_impl("mul", 1)
     descr_max = _reduce_max_min_impl("max")
     descr_min = _reduce_max_min_impl("min")
     descr_argmax = _reduce_argmax_argmin_impl("max")
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -73,11 +73,19 @@
     assert False
 
 
-def find_unaryop_result_dtype(space, dt, promote_to_float=False):
+def find_unaryop_result_dtype(space, dt, promote_to_float=False,
+    promote_to_largest=False):
     if promote_to_float:
         for bytes, dtype in interp_dtype.dtypes_by_num_bytes:
             if dtype.kind == interp_dtype.FLOATINGLTR and dtype.num_bytes >= dt.num_bytes:
                 return space.fromcache(dtype)
+    if promote_to_largest:
+        if dt.kind == interp_dtype.BOOLLTR or dt.kind == interp_dtype.SIGNEDLTR:
+            return space.fromcache(interp_dtype.W_Int64Dtype)
+        elif dt.kind == interp_dtype.FLOATINGLTR:
+            return space.fromcache(interp_dtype.W_Float64Dtype)
+        else:
+            assert False
     return dt
 
 
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -412,6 +412,9 @@
         assert a.sum() == 10.0
         assert a[:4].sum() == 6.0
 
+        a = array([True] * 5, bool)
+        assert a.sum() == 5
+
     def test_prod(self):
         from numpy import array
         a = array(range(1,6))
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -14,7 +14,7 @@
 class TestNumpyJIt(LLJitMixin):
     def setup_class(cls):
         cls.space = FakeSpace()
-        cls.float64_dtype = W_Float64Dtype(cls.space)
+        cls.float64_dtype = cls.space.fromcache(W_Float64Dtype)
 
     def test_add(self):
         def f(i):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit